#!/usr/bin/env python3
"""
Experiment runner that automates starting the tokenizer pipeline and sending
packets from a single terminal, repeated for multiple runs, and aggregates
latency metrics across runs.

Examples:
  # 10 runs, each with 20 packets, Rust (fast) GPT-2 using code snippets
  python tests/python/analyze/run_experiments.py \
    --mode rust --model gpt2 --runs 10 --max-packets 20 --code

  # 5 runs, DPDK BPE ModernBERT-base with multilingual sentences
  python tests/python/analyze/run_experiments.py \
    --mode dpdk --tokenizer bpe --model modernbert-base --runs 5 --max-packets 20 --multilingual

  # 3 runs, 10 packets/run, each packet exactly 512 GPT-2 tokens
  python tests/python/analyze/run_experiments.py \
    --mode dpdk --tokenizer bpe --runs 3 --max-packets 10 --tokens-per-packet 512 --openwebtext

Notes:
  - Uses the same underlying pipelines and sender utilities already in the repo.
  - Aggregates per-run stats: mean, P50 (median), P90, P99 into mean +- std across runs.
  - Supports batch flags and TinyBERT encoding similar to run_tokenizer.py.
"""

import argparse
import statistics
import threading
import queue as _queue
import time
import json
import os
import sys
from typing import Optional, Dict, List, Tuple
import numpy as np
try:
    from transformers import AutoConfig  # type: ignore
    _HAS_TRANSFORMERS = True
except Exception:
    _HAS_TRANSFORMERS = False

# Ensure we can import utils and sender helpers.
# THIS_DIR is absolute, so the appended repo path (../../../src/python relative
# to this script) works regardless of the current working directory.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
REPO_SRC_PY = os.path.join(THIS_DIR, '../../../src/python')
sys.path.append(REPO_SRC_PY)

# Single-core by default: allow overriding via CLI flag detected early.
# The flag is read straight from sys.argv (before argparse runs) so the
# environment below is adjusted before `utils` is imported a few lines down —
# presumably so the tokenizer libraries it loads see these settings; confirm.
# When the flag is present:
#   - ONE_CORE_LIMIT_DISABLED=1 is exported (NOTE(review): its consumer is not
#     visible in this file),
#   - TOKENIZERS_PARALLELISM defaults to "true" unless the user already set it,
#   - any Rayon thread-count overrides are removed from the environment.
_disable_one_core = "--disable-one-core-limit" in sys.argv
if _disable_one_core:
    os.environ["ONE_CORE_LIMIT_DISABLED"] = "1"
    os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get("TOKENIZERS_PARALLELISM", "true")
    os.environ.pop("RAYON_RS_NUM_THREADS", None)
    os.environ.pop("RAYON_NUM_THREADS", None)

from utils import DPDKPipeline, PythonRustPipeline  # type: ignore

# Import sender function (lives in the sibling ../send_packets directory).
SENDER_DIR = os.path.join(THIS_DIR, '../send_packets')
sys.path.append(os.path.abspath(SENDER_DIR))
from send_n_packets import send_test_packets  # type: ignore
# Cleanup helpers to avoid hanging processes on port 6000.
# The module is optional: if it cannot be imported, cleanup_tokenizers is None
# and every cleanup step in this script is skipped.
sys.path.append(os.path.join(THIS_DIR))
try:
    import cleanup_tokenizers  # type: ignore
except Exception:
    cleanup_tokenizers = None


def _is_embed_active(embed_model: Optional[str]) -> bool:
    """Return True when *embed_model* names a real embedding model.

    Empty/None and the sentinels 'standalone' / 'none' (case-insensitive) all
    mean "no embedding model".
    """
    return bool(embed_model) and embed_model.lower() not in ('standalone', 'none')


def _kill_lingering_tokenizers(graceful_first: bool, debug: bool = False) -> None:
    """Kill stray tokenizer processes found by ``cleanup_tokenizers`` (if importable).

    With *graceful_first*, a non-forced kill is attempted first; after a short
    pause, anything still running is force-killed. Otherwise the processes are
    force-killed immediately. Errors from the helper module propagate.
    """
    if not cleanup_tokenizers:
        return
    pids = cleanup_tokenizers.find_tokenizer_processes()
    if not pids:
        return
    if graceful_first:
        if debug:
            print(f"Pre-run cleanup: killing lingering tokenizers {pids}")
        cleanup_tokenizers.kill_processes(pids, force=False)
        time.sleep(0.3)
        pids = cleanup_tokenizers.find_tokenizer_processes()
        if not pids:
            return
        if debug:
            print(f"Pre-run cleanup (force): killing {pids}")
    cleanup_tokenizers.kill_processes(pids, force=True)


def _kill_lingering_tokenizers_quietly() -> None:
    """Force-kill stray tokenizer processes; deliberately best-effort, never raises."""
    try:
        _kill_lingering_tokenizers(graceful_first=False)
    except Exception:
        pass


def _reset_latency_stats(pipeline, enable_batch: bool) -> None:
    """Replace the pipeline's latency accumulator with a fresh, empty one."""
    from utils.stats import LatencyStats  # project module, imported lazily
    pipeline.latency_stats = LatencyStats(is_batch_mode=enable_batch)


def _release_permit(permit_q) -> bool:
    """Give the sender one more send permit; return False if the queue is full."""
    try:
        permit_q.put_nowait(None)
    except _queue.Full:
        return False
    return True


def _record_token_count(result: dict, token_counts: list) -> None:
    """Append the token count of *result* to *token_counts* when it can be determined.

    Prefers an explicit ``NUM_TOKENS`` field (skipping a None value). Otherwise
    it counts ``TOKEN_IDS``, which is either a whitespace-separated string or a
    sized sequence.
    """
    if 'NUM_TOKENS' in result:
        num_tokens = result.get('NUM_TOKENS')
        if num_tokens is not None:
            token_counts.append(num_tokens)
    elif 'TOKEN_IDS' in result:
        token_ids = result.get('TOKEN_IDS', '')
        if isinstance(token_ids, str):
            token_counts.append(len(token_ids.split()))
        elif hasattr(token_ids, '__len__'):
            token_counts.append(len(token_ids))


def _tokenize_us(result: dict):
    """Return the tokenization latency of *result* in microseconds (0 if unknown).

    Results carrying TSC timestamps (DPDK) are converted with the reported TSC
    frequency. Otherwise a precomputed TOKENIZE_LATENCY_US / TOKENIZE_US field
    is used.
    """
    if 'TOKENIZE_START_TIME' in result and 'TOKENIZE_END_TIME' in result and 'TSC_FREQUENCY' in result:
        cycles = result['TOKENIZE_END_TIME'] - result['TOKENIZE_START_TIME']
        tsc = result['TSC_FREQUENCY'] or 1  # guard against a zero/missing frequency
        return (cycles * 1_000_000.0) / tsc
    return result.get('TOKENIZE_LATENCY_US') or result.get('TOKENIZE_US') or 0


def _e2e_us(result: dict, send_times: dict) -> Optional[float]:
    """Return the send-to-delivery latency of *result* in microseconds, or None.

    None means the result has no MESSAGE_ID, no recorded send time, or a
    non-positive latency. A missing DELIVERY_TIME falls back to
    ``time.perf_counter()``. NOTE(review): that fallback assumes the sender's
    on_send timestamps use the same clock — confirm against send_test_packets.
    """
    msg_id = result.get('MESSAGE_ID')
    if msg_id is None:
        return None
    t_send = send_times.get(int(msg_id))
    if not t_send:
        return None
    t_recv = result.get('DELIVERY_TIME', time.perf_counter())
    e2e_us = (t_recv - t_send) * 1_000_000.0
    return e2e_us if e2e_us > 0 else None


def _should_stop_waiting(
    *,
    sender_finished: bool,
    embed_active: bool,
    missing: int,
    idle_s: float,
    embed_settle_ms: int,
    embed_missing_tolerance: int,
    timeout_seconds: int,
    announce_settle: bool,
) -> bool:
    """Decide whether the collection loop should give up after an empty poll.

    While the sender is still running, or for non-embedding runs, the loop stops
    once it has been idle for *timeout_seconds*. Once the sender has finished an
    embedding run, the loop may stop early: that happens when it has been idle
    longer than the settle window and at most *embed_missing_tolerance* packets
    are missing. Otherwise a large 3600 s cap is used to avoid exiting
    prematurely.
    """
    if sender_finished and embed_active:
        settle_s = max(0.0, (embed_settle_ms or 0) / 1000.0)
        if settle_s > 0 and idle_s > settle_s and missing <= (embed_missing_tolerance or 0):
            if announce_settle:
                print(f"Settle window reached; allowing missing={missing} (tolerance={embed_missing_tolerance}).")
            return True
        drain_limit = 3600.0
    else:
        drain_limit = float(timeout_seconds)
    if idle_s > drain_limit:
        print(f"Timeout waiting for packets (>{int(drain_limit)}s).")
        return True
    return False


def _drain_results(pipeline, enable_batch: bool, expected: int, timeout_s: float) -> int:
    """Poll and discard up to *expected* results within *timeout_s*; return how many arrived.

    Stops early when the pipeline exposes no suitable poll method, instead of
    busy-spinning until the deadline.
    """
    drained = 0
    t0 = time.time()
    while drained < expected and (time.time() - t0) < timeout_s:
        if enable_batch:
            if hasattr(pipeline, 'get_batch_results'):
                drained += len(pipeline.get_batch_results(timeout=1.0) or [])
            elif hasattr(pipeline, 'process_batch_packets'):
                drained += len(pipeline.process_batch_packets(timeout=1.0) or [])
            else:
                break
        else:
            if hasattr(pipeline, 'get_result'):
                r = pipeline.get_result(timeout=1.0)
            elif hasattr(pipeline, 'process_packet'):
                r = pipeline.process_packet()
            else:
                break
            if r:
                drained += 1
    return drained


def _e2e_summary(latencies: List[float]) -> dict:
    """Return mean/median/p90/p99 of *latencies* (microseconds).

    With fewer than 10 samples the tail percentiles fall back to the maximum.
    """
    def _tail(n: int, index: int) -> float:
        if len(latencies) < 10:
            return max(latencies)
        try:
            return statistics.quantiles(latencies, n=n)[index]
        except statistics.StatisticsError:
            return max(latencies)

    return {
        'mean': statistics.mean(latencies),
        'median': statistics.median(latencies),
        'p90': _tail(10, 8),
        'p99': _tail(100, 98),
    }


def _attach_cache_metrics(stats: dict, pipeline) -> None:
    """Copy the pipeline's cache counters (if present) into *stats* and add ``cache_hit_ratio``.

    Pipelines without these attributes (Rust/Python) contribute no counters.
    The ratio is None whenever there were no lookups.
    """
    counters = (
        ('cache_lookups', 'cache_lookups_total'),
        ('cache_hits', 'cache_hits_total'),
        ('cache_inserts', 'cache_inserts_total'),
        ('cache_insert_fails', 'cache_insert_fails_total'),
        ('cache_skip_longkey', 'cache_skip_longkey_total'),
        ('cache_skip_oversize', 'cache_skip_oversize_total'),
    )
    for key, attr in counters:
        if hasattr(pipeline, attr):
            stats[key] = int(getattr(pipeline, attr))
    lookups = stats.get('cache_lookups', 0)
    # .get() so a pipeline reporting lookups but not hits yields 0.0, not a KeyError.
    stats['cache_hit_ratio'] = stats.get('cache_hits', 0) / lookups if lookups > 0 else None


def run_single_experiment(
    mode: str,
    tokenizer: str,
    model: str,
    max_packets: int,
    enable_batch: bool,
    batch_size: int,
    encoder: Optional[str],
    embed_model: Optional[str],
    force_cpu: bool,
    debug: bool,
    debug_bert: bool,
    dataset: str,
    timeout_seconds: int,
    tokens_per_packet: int,
    max_seq_len: int,
    delay_ms: int,
    latency_mode: str,
    pin_core: Optional[int],
    rt_prio: Optional[int],
    dpdk_log_level: Optional[str],
    use_sudo: bool,
    use_msg_id_header: bool,
    allow_non_isolated: bool,
    dataset_offset: int = 0,
    disable_cache: bool = False,
    warmup: bool = False,
    external_pipeline=None,
    keep_pipeline_alive: bool = False,
    clear_rust_cache_before_trial: bool = True,
    embed_packet_delay_ms: int = 0,
    embed_settle_ms: int = 30000,
    embed_missing_tolerance: int = 2,
    per_length_warmup_packets: int = 2,
    dup_chunks_override: int | None = None,
):
    """Run one experiment (pipeline + sender) and return a stats dict.

    The run proceeds in these steps:
      1. Check for a GPU when an embedding model is requested.
      2. Clean up stray tokenizer processes (best-effort).
      3. Build the pipeline for *mode*, or reuse *external_pipeline* with fresh
         stats.
      4. Start the pipeline and wait for it to become ready.
      5. For embedding runs only, run an optional per-length warmup.
      6. Start a background sender thread that sends *max_packets* packets to
         port 6000, while this thread collects results until all arrive or a
         timeout/settle condition ends the run.

    Returns:
        ``pipeline.latency_stats.stats()`` extended with:
          - ``e2e_fraction_mean`` / ``e2e_fraction_count`` for embedding runs
            with tokenize timings. For modes other than 'embed', the
            mean/median/p90/p99 values are replaced by send-to-delivery
            latencies.
          - cache counters and ``cache_hit_ratio`` (None without lookups).
          - ``embeddings``: the collected ``BERT_EMBEDDING_DATA`` payloads.
        Returns ``{"count": 0}`` if the pipeline fails to start or become ready.

    Exits the process (``sys.exit(1)``) when an embedding model is requested
    without *force_cpu* and CUDA is unavailable.
    """
    embed_active = _is_embed_active(embed_model)
    # Whether this call is responsible for stopping the pipeline at the end.
    owns_pipeline = external_pipeline is None or not keep_pipeline_alive

    # Embedding models are expected to run on a GPU unless --cpu was given.
    if embed_active and not force_cpu:
        try:
            import torch
            if not torch.cuda.is_available():
                print("ERROR: GPU is required for embedding models but CUDA is not available!")
                print("Please run on a machine with GPU or use --cpu flag (not recommended)")
                sys.exit(1)
            gpu_count = torch.cuda.device_count()
            gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "Unknown"
            if debug:
                print(f"GPU Check: Found {gpu_count} GPU(s). Primary: {gpu_name}")
        except ImportError:
            print("WARNING: Cannot verify GPU availability (torch not installed)")
    elif embed_model and force_cpu:
        print("WARNING: Running embedding model on CPU (--cpu flag set). This will be slow!")

    # Pre-run cleanup: kill lingering tokenizer processes (best-effort).
    # Only the log lines depend on `debug`; the kill itself always happens.
    if external_pipeline is None:
        try:
            _kill_lingering_tokenizers(graceful_first=True, debug=debug)
        except Exception as e:
            print(f"Warning: cleanup before run failed: {e}")

    # Construct or reuse the pipeline.
    if external_pipeline is not None:
        pipeline = external_pipeline
        # Reset stats per trial
        _reset_latency_stats(pipeline, enable_batch)
        # For fairness when reusing Rust tokenizer, optionally clear cache per trial
        if clear_rust_cache_before_trial and hasattr(pipeline, 'clear_model_cache'):
            pipeline.clear_model_cache()
    elif mode == 'dpdk':
        pipeline = DPDKPipeline(
            tokenizer_type=tokenizer,
            model=model,
            encoder_type=encoder,
            encoder_model=embed_model,
            force_cpu=force_cpu,
            debug_bert=debug_bert,
            debug=debug,
            batch_size=batch_size,
            enable_batch=enable_batch,
            latency_mode=latency_mode,
            pin_core=pin_core,
            rt_prio=rt_prio,
            dpdk_log_level=dpdk_log_level,
            use_sudo=use_sudo,
            allow_non_isolated=allow_non_isolated,
            disable_cache=disable_cache,
        )
    elif mode in ('rust', 'python', 'embed'):
        # rust => use_fast=True; python/embed => use_fast=False
        pipeline = PythonRustPipeline(
            model_name=model,
            encoder_type=encoder,
            encoder_model=embed_model,
            force_cpu=force_cpu,
            debug=debug,
            batch_size=batch_size,
            enable_batch=enable_batch,
            use_fast=(mode == 'rust'),
            latency_mode=latency_mode,
            warmup=warmup,
            disable_cache=disable_cache,
            pin_core=None,
        )
        if mode == 'embed' and hasattr(pipeline, 'measure_encode_only'):
            pipeline.measure_encode_only = True
    else:
        raise RuntimeError(f"Unknown mode: {mode}")

    # Start pipeline (no fallback). Fail fast on errors with a clear message.
    if external_pipeline is None:
        try:
            if debug:
                print(f"[DEBUG] Starting pipeline for mode={mode}, tokens={tokens_per_packet}")
            pipeline.start()
            if debug:
                print("[DEBUG] Pipeline started successfully")
        except Exception as e:
            print(f"ERROR: Failed to start pipeline for mode={mode}, tokens={tokens_per_packet}")
            print(str(e))
            # Don't leave background processes around.
            try:
                pipeline.stop()
            except Exception:
                pass
            _kill_lingering_tokenizers_quietly()
            # Empty stats let the caller terminate gracefully.
            return {"count": 0}

    # Wait for pipeline ready (especially important for DPDK + sudo).
    if external_pipeline is None and hasattr(pipeline, 'wait_ready'):
        if not pipeline.wait_ready(timeout=30.0):
            print("ERROR: Tokenizer pipeline did not become ready within 30s.")
            try:
                pipeline.stop()
            except Exception:
                pass
            _kill_lingering_tokenizers_quietly()
            return {"count": 0}

    # Sender arguments shared by the warmup and the measured run.
    dup_chunks = dup_chunks_override if dup_chunks_override is not None else (1 if use_msg_id_header else 0)
    send_kwargs = dict(
        target_port=6000,
        lowercase=False,
        use_openwebtext=(dataset == 'openwebtext'),
        use_multilingual=(dataset == 'multilingual'),
        use_code=(dataset == 'code'),
        tokens_per_packet=tokens_per_packet,
        max_seq_len=max_seq_len,
        delay_ms=delay_ms,
        dataset_offset=dataset_offset,
        use_msg_id_header=use_msg_id_header,
        dup_chunks=dup_chunks,
        verbose=debug,
    )

    # Optional per-length warmup (embedding runs only): send a few packets,
    # drain them, then reset stats so they don't contaminate measurements.
    warmup_n = min(max(0, int(per_length_warmup_packets or 0)), max_packets) if embed_active else 0
    if warmup_n > 0:
        had_dedup = hasattr(pipeline, '_dedup_enabled')
        prev_dedup = getattr(pipeline, '_dedup_enabled', None)
        try:
            # Disable deduplication so every END produces a result to drain.
            if had_dedup:
                pipeline._dedup_enabled = False
            send_test_packets(
                num_packets=warmup_n,
                on_send=lambda msg_id, t_send: None,
                packet_delay_ms=0,
                permit_queue=None,
                **send_kwargs,
            )
            _drain_results(pipeline, enable_batch, warmup_n, float(timeout_seconds))
        except Exception as e:
            print(f"Warning: warmup phase failed: {e}")
        finally:
            # Always reset stats and restore dedup, even if the warmup failed.
            try:
                _reset_latency_stats(pipeline, enable_batch)
            except Exception:
                pass
            try:
                if hasattr(pipeline, '_seen_packets'):
                    pipeline._seen_packets = set()
                if had_dedup:
                    pipeline._dedup_enabled = prev_dedup
            except Exception:
                pass

    # Start sender in a background thread.
    sender_done = threading.Event()
    send_times = {}  # msg_id -> timestamp passed to the sender's on_send callback
    # Sequential pacing for embedding runs: the sender is handed one permit per
    # collected result (plus one up front for the first packet).
    permit_q = None
    if embed_active:
        permit_q = _queue.Queue(max_packets if max_packets > 0 else 0)
        _release_permit(permit_q)  # allow first packet
    # With sequential pacing the fixed per-packet delay is disabled.
    # NOTE(review): permit_q is set exactly when embed_active, so this is always
    # 0 and embed_packet_delay_ms currently has no effect.
    packet_delay_ms = 0 if permit_q is not None else (embed_packet_delay_ms if embed_active else 0)

    def _sender():
        def on_send(msg_id, t_send):
            try:
                send_times[int(msg_id)] = float(t_send)
            except Exception:
                pass  # bookkeeping must never break the sender
        send_test_packets(
            num_packets=max_packets,
            on_send=on_send,
            packet_delay_ms=packet_delay_ms,
            permit_queue=permit_q,
            **send_kwargs,
        )
        sender_done.set()

    sender_thread = threading.Thread(target=_sender, daemon=True)
    # Small delay to help ensure the pipeline is listening before sending.
    time.sleep(0.5)
    sender_thread.start()

    # Process until max_packets or a timeout/settle condition is reached.
    processed = 0
    start_wait = time.time()  # reset on every result; measures idle time
    next_progress = start_wait + 2.0
    e2e_fractions = []  # per-packet tokenize/e2e fraction (embedding runs)
    e2e_latencies = []  # per-packet end-to-end latencies in us (embedding runs)
    collected_embeddings = []  # stored for cross-mode comparison
    token_counts = []  # actual token counts fed to the model
    try:
        if enable_batch:
            while processed < max_packets:
                batch_results = []
                if hasattr(pipeline, 'get_batch_results'):
                    batch_results = pipeline.get_batch_results(timeout=5.0)
                elif hasattr(pipeline, 'process_batch_packets'):
                    batch_results = pipeline.process_batch_packets(timeout=5.0)

                if batch_results:
                    processed += len(batch_results)
                    start_wait = time.time()
                    # Release next permits (sequential pacing).
                    if permit_q is not None:
                        for _ in batch_results:
                            _release_permit(permit_q)
                    for r in batch_results:
                        if 'BERT_EMBEDDING_DATA' in r:
                            collected_embeddings.append(r['BERT_EMBEDDING_DATA'])
                        _record_token_count(r, token_counts)
                    if embed_active:
                        for r in batch_results:
                            e2e_us = _e2e_us(r, send_times)
                            if e2e_us is None:
                                continue
                            tok_us = _tokenize_us(r)
                            if tok_us:
                                e2e_fractions.append(max(0.0, min(1.0, tok_us / e2e_us)))
                                e2e_latencies.append(e2e_us)
                elif _should_stop_waiting(
                    sender_finished=sender_done.is_set(),
                    embed_active=embed_active,
                    missing=max(0, max_packets - processed),
                    idle_s=time.time() - start_wait,
                    embed_settle_ms=embed_settle_ms,
                    embed_missing_tolerance=embed_missing_tolerance,
                    timeout_seconds=timeout_seconds,
                    announce_settle=debug,
                ):
                    break
                now = time.time()
                if now >= next_progress:
                    if debug:
                        print(f"...progress: {processed}/{max_packets} packets processed")
                    next_progress = now + 2.0
        else:
            consecutive_timeouts = 0
            max_consecutive_timeouts = 3
            while processed < max_packets:
                result = None
                if hasattr(pipeline, 'get_result'):
                    # DPDK pipeline exposes results from its monitor thread.
                    result = pipeline.get_result(timeout=5.0)
                    if result is None:
                        consecutive_timeouts += 1
                        # Check if DPDK process is still alive.
                        if hasattr(pipeline, 'process') and pipeline.process:
                            if pipeline.process.poll() is not None:
                                print(f"ERROR: DPDK process died with code {pipeline.process.returncode}")
                                break
                        if consecutive_timeouts >= max_consecutive_timeouts:
                            print(f"WARNING: No response after {consecutive_timeouts} attempts. Allowing {max_packets - processed} missing packets.")
                            # For experiments, better to continue with missing data than hang forever.
                            break
                elif hasattr(pipeline, 'process_packet'):
                    # HF tokenizer pipeline actively processes from socket.
                    result = pipeline.process_packet()
                    if result is None:
                        consecutive_timeouts += 1
                        if consecutive_timeouts >= max_consecutive_timeouts and permit_q is not None:
                            # Release a permit so the sender can retry.
                            if _release_permit(permit_q):
                                print(f"[RETRY] No response after {consecutive_timeouts} timeouts, releasing permit to retry packet")
                                consecutive_timeouts = 0

                if result:
                    consecutive_timeouts = 0
                    processed += 1
                    start_wait = time.time()
                    if processed == 1 and debug:
                        print(f"[DEBUG] First result received for mode={mode}, tokens={tokens_per_packet}")
                    if 'BERT_EMBEDDING_DATA' in result:
                        collected_embeddings.append(result['BERT_EMBEDDING_DATA'])
                    _record_token_count(result, token_counts)
                    # Release next permit (sequential pacing).
                    if permit_q is not None:
                        _release_permit(permit_q)
                    if embed_active:
                        e2e_us = _e2e_us(result, send_times)
                        if e2e_us is not None:
                            tok_us = _tokenize_us(result)
                            if tok_us:
                                e2e_fractions.append(max(0.0, min(1.0, tok_us / e2e_us)))
                            # Unlike batch mode, the latency is kept even without tokenize timing.
                            e2e_latencies.append(e2e_us)
                            if debug:
                                try:
                                    num_tokens = result.get('NUM_TOKENS')
                                    enc_us = result.get('BERT_ENCODE_TIME_US')
                                    print(f"[E2E] mode={mode} tokens_pp={tokens_per_packet} msg={result.get('MESSAGE_ID')} e2e_us={e2e_us:.2f} tok_us={float(tok_us) if tok_us else 0:.2f} enc_us={float(enc_us) if enc_us else 0:.2f} num_tokens={num_tokens}")
                                except Exception:
                                    pass
                elif _should_stop_waiting(
                    sender_finished=sender_done.is_set(),
                    embed_active=embed_active,
                    missing=max(0, max_packets - processed),
                    idle_s=time.time() - start_wait,
                    embed_settle_ms=embed_settle_ms,
                    embed_missing_tolerance=embed_missing_tolerance,
                    timeout_seconds=timeout_seconds,
                    announce_settle=True,
                ):
                    break
                now = time.time()
                if now >= next_progress:
                    print(f"...progress: {processed}/{max_packets} packets processed")
                    next_progress = now + 2.0
    finally:
        # Bounded join: if the loop gave up early, the sender may be stuck
        # waiting for a permit that will never come (permits are only released
        # as results arrive; NOTE(review): depends on how send_test_packets
        # consumes permit_queue). An unbounded join would hang the experiment.
        # The thread is a daemon, so abandoning it cannot block interpreter exit.
        join_timeout = max(5.0, float(timeout_seconds))
        try:
            sender_thread.join(timeout=join_timeout)
        except RuntimeError:
            pass
        if sender_thread.is_alive():
            print(f"Warning: sender thread still running after {join_timeout:.0f}s; continuing without it.")
        if owns_pipeline:
            pipeline.stop()
            _kill_lingering_tokenizers_quietly()

    stats = pipeline.latency_stats.stats()
    if debug:
        print(f"[DEBUG] Final stats for mode={mode}, tokens={tokens_per_packet}: count={stats.get('count')}, mean={stats.get('mean')}, processed={processed}")

    # Report token count statistics to detect padding differences.
    if token_counts and embed_active:
        avg_tokens = statistics.mean(token_counts)
        min_tokens = min(token_counts)
        max_tokens = max(token_counts)
        std_tokens = statistics.stdev(token_counts) if len(token_counts) > 1 else 0
        print(f"[TOKEN STATS] Mode={mode}: avg={avg_tokens:.1f}, min={min_tokens}, max={max_tokens}, std={std_tokens:.2f}")
        # Significant variation may indicate a padding issue.
        if max_tokens - min_tokens > 50:
            print("WARNING: Large token count variation detected! This may indicate padding differences.")
            print(f"         Expected ~{tokens_per_packet} tokens per packet")

    # Attach E2E metrics for embedding runs.
    if embed_active and e2e_fractions:
        try:
            stats['e2e_fraction_mean'] = statistics.mean(e2e_fractions)
            stats['e2e_fraction_count'] = len(e2e_fractions)
            # Replace primary latency stats with true E2E metrics for fairness,
            # except in 'embed' mode, which keeps its encode-only timing.
            if e2e_latencies and mode != 'embed':
                stats.update(_e2e_summary(e2e_latencies))
        except Exception as e:
            print(f"Warning: failed to attach E2E metrics: {e}")

    # Attach cache metrics for DPDK if available (absent for Rust/Python).
    try:
        _attach_cache_metrics(stats, pipeline)
    except Exception as e:
        print(f"Warning: failed to attach cache metrics: {e}")

    # Include collected embeddings in the return.
    stats['embeddings'] = collected_embeddings
    return stats


def aggregate_runs(run_stats: list[dict]) -> dict:
    """Collapse per-run latency stats (microseconds) into mean +- std per metric.

    Runs whose ``count`` is 0 (or missing) are left out of the metric
    aggregates, but still counted in ``runs`` and ``per_run_counts``. A metric
    with no contributing runs aggregates to ``{"mean": None, "std": None}``;
    a single run gets a std of 0.0.
    """
    completed = [s for s in run_stats if s.get('count', 0) > 0]

    def _mean_std(key: str) -> dict:
        samples = [s.get(key) for s in completed]
        if not samples:
            return {"mean": None, "std": None}
        spread = statistics.stdev(samples) if len(samples) > 1 else 0.0
        return {"mean": statistics.mean(samples), "std": spread}

    return {
        "runs": len(run_stats),
        "per_run_counts": [s.get('count', 0) for s in run_stats],
        "mean_us": _mean_std('mean'),
        "p50_us": _mean_std('median'),
        "p90_us": _mean_std('p90'),
        "p99_us": _mean_std('p99'),
    }


def fmt_us(value_us: float | None) -> str:
    if value_us is None:
        return "n/a"
    if value_us >= 1000.0:
        return f"{value_us/1000.0:.2f} ms"
    return f"{value_us:.2f} us"


def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Return the cosine similarity of two arrays, each treated as a flat vector.

    Inputs of any shape are flattened first, so they only need the same total
    number of elements. Returns 0.0 when either vector has zero norm.
    """
    flat_a = vec1.flatten()
    flat_b = vec2.flatten()
    dot = np.dot(flat_a, flat_b)
    mag_a = np.linalg.norm(flat_a)
    mag_b = np.linalg.norm(flat_b)
    # Zero-length vectors have no direction; report no similarity.
    if not (mag_a and mag_b):
        return 0.0
    return float(dot / (mag_a * mag_b))


def compare_embeddings(embeddings_dict: Dict[str, List[np.ndarray]], output_path: str) -> Dict[Tuple[str, str], List[float]]:
    """Compute per-packet cosine similarity for every pair of modes and save it as CSV.

    Args:
        embeddings_dict: Mapping of mode name -> list of per-packet embeddings.
            Packets are paired by list position. When two modes have different
            numbers of packets, only the common prefix is compared.
        output_path: Destination CSV path. Missing parent directories are
            created. Columns are ``packet_id`` (1-based) followed by one
            ``<modeA>_vs_<modeB>`` column per mode pair. A cell is left blank
            where that pair has fewer packets than the longest pair.

    Returns:
        Mapping of ``(mode_a, mode_b)`` -> list of similarities. Pairs follow the
        insertion order of ``embeddings_dict``, with ``mode_a`` always listed
        before ``mode_b``. CSV write failures are reported as a warning (best
        effort) and do not affect the returned results.
    """
    from itertools import combinations

    comparison_results: Dict[Tuple[str, str], List[float]] = {}
    # combinations() yields each unordered pair once, in the same order as a
    # nested "i < j" loop over the dict's keys.
    for mode1, mode2 in combinations(embeddings_dict, 2):
        # zip() stops at the shorter list, i.e. only common packets are compared.
        comparison_results[(mode1, mode2)] = [
            cosine_similarity(emb1, emb2)
            for emb1, emb2 in zip(embeddings_dict[mode1], embeddings_dict[mode2])
        ]

    # Save detailed comparison to CSV
    try:
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        headers = ['packet_id'] + [f"{a}_vs_{b}" for a, b in comparison_results]
        # default=0: with fewer than two modes there are no pairs, so only the
        # header is written instead of max() raising on an empty sequence.
        max_packets = max((len(sims) for sims in comparison_results.values()), default=0)
        with open(output_path, 'w') as f:
            f.write(','.join(headers) + '\n')
            for packet_id in range(max_packets):
                row = [str(packet_id + 1)]
                for sims in comparison_results.values():
                    row.append(f"{sims[packet_id]:.6f}" if packet_id < len(sims) else '')
                f.write(','.join(row) + '\n')

        print(f"Saved embedding comparison to {output_path}")
    except Exception as e:
        print(f"Warning: Failed to save embedding comparison: {e}")

    return comparison_results


def main():
    parser = argparse.ArgumentParser(description="Automated experiment runner for tokenizer pipelines")
    parser.add_argument("--mode", choices=["dpdk", "rust", "python", "embed", "all"], nargs='+', default=None,
                        help="One or more modes to test (space-separated). Use 'all' for dpdk rust python")
    parser.add_argument("--model", default="gpt2")
    parser.add_argument("--tokenizer", choices=["simple", "wordpiece", "bpe"], default="bpe")
    parser.add_argument("--runs", type=int, default=1, help="Number of experiments to run")
    parser.add_argument("--max-packets", type=int, default=1, help="Packets per run")
    # Batching: if --batch-size > 1, batching is enabled automatically.
    parser.add_argument("--enable-batch", action="store_true",
                        help="[Deprecated] Batching is enabled automatically when --batch-size > 1")
    parser.add_argument("--batch-size", type=int, default=1,
                        help="Batch size (default: 1 = no batching). If > 1, enables batching.")
    parser.add_argument("--encoder", choices=["tinybert"], default=None, help="Legacy encoder shortcut (e.g., tinybert). Prefer --embed-model for HF model name.")
    parser.add_argument("--embed-model", default="standalone", help="HF model name for embedding (e.g., 'answerdotai/ModernBERT-base', 'intfloat/e5-base-v2', or 'standalone' to disable)")
    parser.add_argument("--embed-packet-delay-ms", type=int, default=100,
                        help="Inter-packet delay (ms) to apply when --embed-model is set (default: 100)")
    parser.add_argument("--compare-embeddings", action="store_true",
                        help="Enable embedding comparison across modes (uses extra memory)")
    parser.add_argument("--cpu", action="store_true", help="Force CPU for encoder")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--debug-bert", action="store_true")
    parser.add_argument("--timeout", type=int, default=60, help="Per-run inactivity timeout seconds")
    parser.add_argument("--latency-mode", choices=["end2end", "tokenize-only"], default="tokenize-only",
                        help="Latency aggregation: tokenize-only (default, algorithm only) or end2end")
    parser.add_argument("--disable-cache", action="store_true",
                        help="Disable tokenizer caches (Rust fast/Python and DPDK BPE)")
    # UDP sender tuning
    parser.add_argument("--dup-chunks", type=int, default=None,
                        help="Duplicate each UDP chunk this many extra times (default: 1 when msg-id header is used; 0 otherwise)")
    # Warmup: enabled by default unless --no-warmup is set
    parser.add_argument("--warmup", action="store_true",
                        help="Warm up tokenizers/models before measurement (default: enabled)")
    parser.add_argument("--no-warmup", action="store_true",
                        help="Disable warmup (override default)")
    parser.add_argument("--disable-rust-cache-clearing-per-trial", action="store_true",
                        help="Keep Rust fast tokenizer cache across trials (unfair vs. DPDK/Python; use for sensitivity analysis)")
    # DPDK runtime tuning shortcuts
    parser.add_argument("--pin-core", type=int, default=None,
                        help="Pin DPDK tokenizer process to a CPU core (sets DPDK_PIN_CORE for child)")
    parser.add_argument("--rt-prio", type=int, default=None,
                        help="Run tokenizer with SCHED_FIFO at priority (sets DPDK_RT_PRIO; needs sudo/CAP_SYS_NICE)")
    parser.add_argument("--dpdk-log-level", default=None,
                        help="DPDK EAL log level (0..8) for child (overrides DPDK_LOG_LEVEL)")
    parser.add_argument("--disable-sudo", action="store_true",
                        help="Do NOT use sudo to launch the DPDK tokenizer (default is sudo for best perf)")
    parser.add_argument("--no-msg-id-header", action="store_true",
                        help="Disable extended UDP header (default: enabled)")
    parser.add_argument("--allow-non-isolated", action="store_true",
                        help="Allow running when pin core is not in isolcpus (reduced determinism)")
    parser.add_argument("--disable-one-core-limit", action="store_true",
                        help="Allow tokenizers to use multiple threads (default is single-core)")
    # Token-controlled sending
    parser.add_argument("--tokens-per-packet", type=int, nargs='+', default=None,
                        help="Global default token counts per packet (e.g., 128 256 512)")
    # Per-model overrides for token counts
    parser.add_argument("--tpp-override", action='append', default=None,
                        help="Per-model override in the form 'model_id=t1,t2,...'. May be repeated. Use '*' for wildcard.")
    parser.add_argument("--tpp-config", default=None,
                        help="Optional YAML/JSON file with per-model tokens-per-packet configuration.")
    parser.add_argument("--print-plan", action='store_true',
                        help="Print the resolved tokens-per-packet plan per model before running")
    parser.add_argument("--max-seq-len", type=int, default=2048,
                        help="Receiver's maximum supported tokens (default: 2048)")
    parser.add_argument("--delay-ms", type=int, default=0,
                        help="Optional inter-chunk delay in ms for sender pacing")
    # Dataset selection
    parser.add_argument("--dataset", choices=["simple", "openwebtext", "multilingual", "code", "all"], nargs='+', default=None,
                        help="Datasets to test (space-separated). Use 'all' for openwebtext multilingual code")
    # Test output organization
    parser.add_argument("--test-name", default='last_test',
                        help="Name of the test run; results saved under tests/python/analyze/results/<test-name>/")
    parser.add_argument("--override", action="store_true",
                        help="Allow overwriting an existing test folder without prompting")
    parser.add_argument("--out-file", default=None,
                        help="CSV output path for aggregated results (defaults to results/<test-name>/results_summary.csv)")
    # Plot-only mode
    parser.add_argument("--plot-only", action="store_true", help="Skip runs and plot from --out-file CSV")
    parser.add_argument("--plot-metric", choices=["avg", "p50", "p90", "p99"], default="avg",
                        help="Metric group to plot when --plot-only is used")
    parser.add_argument("--plot-dataset", nargs='*', default=None,
                        help="Datasets to plot (default: all in CSV) when --plot-only is used")
    parser.add_argument("--plot-out-dir", default=None,
                        help="Directory to save plots (default: results/<test-name>/)")
    # Optional JSON save
    parser.add_argument("--save-json", default=None, help="Path to save per-run and aggregate stats JSON")

    args = parser.parse_args()

    # Fast defaults: pick a sane core, RT prio, and log level when not provided
    cpu_count = os.cpu_count() or 1
    default_core = 0
    effective_pin_core = args.pin_core if args.pin_core is not None else default_core
    # Use higher RT priority by default for lower jitter
    effective_rt_prio = args.rt_prio if args.rt_prio is not None else 80
    effective_dpdk_log = args.dpdk_log_level if args.dpdk_log_level is not None else "3"
    # Derive effective warmup and batching behavior
    effective_warmup = (not getattr(args, 'no_warmup', False)) or bool(getattr(args, 'warmup', False))
    effective_enable_batch = bool(getattr(args, 'enable_batch', False)) or (getattr(args, 'batch_size', 1) and args.batch_size > 1)

    # Establish results directory and derived defaults
    results_root = os.path.join(THIS_DIR, 'results')
    test_dir = os.path.join(results_root, args.test_name)
    # Add suffix to all new-style results when an embedding model is used
    # For plot-only mode, don't add suffix yet - will be handled per model
    if args.plot_only:
        model_suffix = ""
    else:
        model_suffix = "" if (not args.embed_model or args.embed_model.lower() in ("standalone", "none")) else ("_" + args.embed_model.replace('/', '_'))
    # Default output paths if not provided
    if not args.out_file:
        args.out_file = os.path.join(test_dir, f'results_summary{model_suffix}.csv')
    if not getattr(args, 'plot_out_dir', None):
        args.plot_out_dir = test_dir

    # Build embed-model list (support --embed-model all)
    def _sanitize_suffix(name: str) -> str:
        return name.replace('/', '_').replace(':', '_')
    if args.embed_model and args.embed_model.lower() == 'all':
        embed_models = [
            'google/embeddinggemma-300m',
            'intfloat/e5-base-v2',
            'answerdotai/ModernBERT-base',
        ]
    else:
        embed_models = [args.embed_model or 'standalone']

    # Parse per-model TPP overrides and config
    def _parse_tpp_str_to_list(val: str) -> list[int]:
        toks = []
        for part in val.replace(',', ' ').split():
            try:
                num = int(float(part))
                if num > 0:
                    toks.append(num)
            except Exception:
                continue
        # Deduplicate while preserving order
        seen = set()
        out = []
        for t in toks:
            if t not in seen:
                out.append(t)
                seen.add(t)
        return out

    tpp_overrides_map: dict[str, list[int]] = {}
    tpp_default_from_cfg: list[int] | None = None
    # From --tpp-override repeated flags
    if args.tpp_override:
        for item in args.tpp_override:
            try:
                if '=' not in item:
                    continue
                key, vals = item.split('=', 1)
                key = key.strip()
                arr = _parse_tpp_str_to_list(vals)
                if key and arr:
                    tpp_overrides_map[key] = arr
            except Exception:
                pass
    # From --tpp-config file (YAML or JSON)
    if args.tpp_config:
        try:
            cfg_path = args.tpp_config
            data = None
            if cfg_path.lower().endswith(('.yml', '.yaml')):
                try:
                    import yaml  # type: ignore
                    with open(cfg_path, 'r') as f:
                        data = yaml.safe_load(f)
                except Exception:
                    data = None
            if data is None:
                import json
                with open(cfg_path, 'r') as f:
                    data = json.load(f)
            if isinstance(data, dict):
                # Support keys: default, overrides
                if 'default' in data and data['default'] is not None:
                    try:
                        tpp_default_from_cfg = _parse_tpp_str_to_list(' '.join(map(str, data['default']))) if isinstance(data['default'], (list, tuple)) else _parse_tpp_str_to_list(str(data['default']))
                    except Exception:
                        tpp_default_from_cfg = None
                ov = data.get('overrides') if isinstance(data.get('overrides'), dict) else data
                for k, v in (ov or {}).items():
                    try:
                        arr = _parse_tpp_str_to_list(' '.join(map(str, v)) if isinstance(v, (list, tuple)) else str(v))
                        if arr:
                            tpp_overrides_map[str(k)] = arr
                    except Exception:
                        continue
        except Exception as _cfg_ex:
            print(f"Warning: failed to parse --tpp-config: {args.tpp_config}: {_cfg_ex}")

    # Optional plan printout
    if args.print_plan:
        print("Tokens-per-packet plan (preliminary):")
        # just show available info; the final filtering by model caps happens later per model
        if args.tokens_per_packet:
            print(f"  Default (--tokens-per-packet): {', '.join(map(str, args.tokens_per_packet))}")
        if tpp_default_from_cfg:
            print(f"  Default from --tpp-config: {', '.join(map(str, tpp_default_from_cfg))}")
        for k, v in tpp_overrides_map.items():
            print(f"  Override {k}: {', '.join(map(str, v))}")

    # Helper: embed-model grids (latency and throughput) per dataset
    def _plot_latency_grid_by_model(model_rows_map: dict[str, list[dict]], dataset: str, out_dir: str, outfile: str, log_y: bool = False):
        try:
            import matplotlib.pyplot as plt
            from matplotlib.ticker import MaxNLocator, FuncFormatter, LogLocator, LogFormatterMathtext
            # Build latency series for a given metric
            def _build_series(rows: list[dict], metric: str) -> dict[str, list[tuple[int, float, float]]]:
                col = {
                    "pintok": (f"{metric}_pintok_mean_us", f"{metric}_pintok_std_us"),
                    "rust": (f"{metric}_rust_mean_us", f"{metric}_rust_std_us"),
                    "python": (f"{metric}_python_mean_us", f"{metric}_python_std_us"),
                    "embed": (f"{metric}_embed_mean_us", f"{metric}_embed_std_us"),
                }
                series = {"pintok": [], "rust": [], "python": [], "embed": []}
                for r in rows:
                    if r.get('dataset') != dataset:
                        continue
                    try:
                        tokens = int(float(r.get('tokens_per_packet', '')))
                    except Exception:
                        continue
                    for mode, (mcol, scol) in col.items():
                        try:
                            mean = float(r.get(mcol)) if r.get(mcol) else None
                            std = float(r.get(scol)) if r.get(scol) else None
                        except Exception:
                            mean, std = None, None
                        if mean is not None:
                            series[mode].append((tokens, mean, std or 0.0))
                for mode in list(series.keys()):
                    series[mode].sort(key=lambda t: t[0])
                return series

            COLORS = {"python": "#FDB515", "rust": "#002676", "pintok": "#8C1515", "embed": "#555555"}
            LABELS = {"python": "Python", "rust": "Rust", "pintok": "PinTok", "embed": "ModelOnly"}
            METRICS = ["avg", "p50", "p90", "p99"]
            METRIC_TITLES = {"avg": "Average", "p50": "P50", "p90": "P90", "p99": "P99"}

            model_names = [m for m, rows in model_rows_map.items() if any(r.get('dataset') == dataset for r in rows)]
            if not model_names:
                return False

            plt.rcParams.update({
                'font.size': 6.5,
                'axes.titlesize': 6.5,
                'axes.labelsize': 6.0,
                'xtick.labelsize': 4.0,
                'ytick.labelsize': 5.0,
                'legend.fontsize': 5.5,
                'axes.linewidth': 0.6,
                'xtick.major.width': 0.5,
                'ytick.major.width': 0.5,
                'xtick.major.size': 2.0,
                'ytick.major.size': 2.0,
                'grid.linewidth': 0.4,
            })

            nrows = len(model_names)
            ncols = len(METRICS)
            # Make the overall figure more compact by reducing height ~20%
            _fig_height = (1.1 + 1.1 * nrows) * 0.8
            fig, axs = plt.subplots(nrows, ncols, figsize=(4.4, _fig_height))
            if nrows == 1:
                axs = [axs]

            def _plot_cell(ax, series: dict[str, list[tuple[int, float, float]]]):
                any_points = False
                all_x_values = []
                for mode in ["python", "rust", "pintok", "embed"]:
                    pts = series.get(mode, [])
                    if not pts:
                        continue
                    any_points = True
                    xs = [t for (t, _, _) in pts]
                    ys = [m for (_, m, _) in pts]
                    es = [s for (_, _, s) in pts]
                    all_x_values.extend(xs)
                    if len(xs) == 1:
                        ax.scatter(xs, ys, color=COLORS[mode], label=LABELS[mode], marker='o', s=12, linewidths=0.0)
                    else:
                        ax.plot(xs, ys, color=COLORS[mode], label=LABELS[mode], marker='o', linewidth=1.2, markersize=2.5)
                        # Fix for numpy array comparison
                        has_errors = False
                        try:
                            has_errors = any(float(e) > 0 for e in es)
                        except (TypeError, ValueError):
                            has_errors = False
                        if (not log_y) and has_errors:
                            lower = [y - e for y, e in zip(ys, es)]
                            upper = [y + e for y, e in zip(ys, es)]
                            ax.fill_between(xs, lower, upper, color=COLORS[mode], alpha=0.15, linewidth=0)
                # Draw grid for both major and minor ticks to ensure alignment,
                # especially important on log-y where labeled ticks may be minors.
                ax.grid(True, which='both', axis='both', linestyle='--', alpha=0.25, linewidth=0.4)

                # Dynamic x-axis limits based on data with padding
                if all_x_values:
                    x_min, x_max = min(all_x_values), max(all_x_values)
                    x_range = x_max - x_min if x_max > x_min else 100
                    x_padding = x_range * 0.1  # Add 10% of range as padding on each side
                    ax.set_xlim(max(0, x_min - x_padding), x_max + x_padding)

                    # Dynamic x-ticks based on data range
                    if x_max - x_min <= 600:
                        # Small range: show every 128 or 256 tokens
                        tick_step = 128 if x_max - x_min <= 300 else 256
                    else:
                        # Larger range: show every 256 or 512 tokens
                        tick_step = 256 if x_max - x_min <= 1000 else 512
                    x_ticks = list(range(0, int(x_max + x_padding) + 1, tick_step))
                    # Ensure we include the actual data points if they're round numbers
                    for x in all_x_values:
                        if x % 256 == 0 and x not in x_ticks:
                            x_ticks.append(x)
                    x_ticks = sorted([t for t in x_ticks if t >= max(0, x_min - x_padding) and t <= x_max + x_padding])
                    ax.set_xticks(x_ticks)
                else:
                    # Fallback to default if no data
                    ax.set_xlim(0, 2100)
                    ax.set_xticks([500, 1000, 1500, 2000])
                ax.xaxis.set_major_formatter(FuncFormatter(lambda x, p: f"{int(x):,}"))
                if log_y:
                    ax.set_yscale('log')
                else:
                    # Locator will be finalized after we set per-row y-lims
                    ax.yaxis.set_major_locator(MaxNLocator(nbins=4))
            # Precompute per-row y-limits to align subplots within the same row
            import math
            series_by_row_col: dict[int, dict[str, dict[str, list[tuple[int, float, float]]]] ] = {}
            row_y_lims: dict[int, tuple[float, float]] = {}
            for r, model in enumerate(model_names):
                mrows = model_rows_map.get(model, [])
                series_by_row_col[r] = {}
                y_min = math.inf
                y_max = -math.inf
                for metric in METRICS:
                    s = _build_series(mrows, metric)
                    series_by_row_col[r][metric] = s
                    for mode in ["python", "rust", "pintok", "embed"]:
                        for (_, mean, std) in s.get(mode, []):
                            if log_y:
                                # No shading/error region on log plots; include the mean only
                                if mean is None:
                                    continue
                                y_min = min(y_min, float(mean))
                                y_max = max(y_max, float(mean))
                            else:
                                e = float(std or 0.0)
                                y_min = min(y_min, float(mean) - e)
                                y_max = max(y_max, float(mean) + e)
                if y_min is math.inf or y_max is -math.inf:
                    # No data in this row; fallback defaults
                    y_min, y_max = (0.0, 1.0)
                if not log_y:
                    y_min = max(0.0, y_min)
                    span = max(1e-9, y_max - y_min)
                    pad = 0.05 * span
                    row_y_lims[r] = (max(0.0, y_min - pad), y_max + pad)
                else:
                    # Ensure positive bounds on log scale and add multiplicative padding
                    y_min = max(y_min, 1e-12)
                    factor = 1.15
                    row_y_lims[r] = (y_min / factor, y_max * factor)

            # Plot using precomputed series and apply aligned y-limits per row
            for r, model in enumerate(model_names):
                for c, metric in enumerate(METRICS):
                    ax = axs[r][c]
                    series = series_by_row_col[r][metric]
                    _plot_cell(ax, series)
                    # Titles/labels
                    if r == 0:
                        ax.set_title(METRIC_TITLES[metric])
                    if c == 0:
                        # Show shorter model name (drop any prefix like 'org/')
                        try:
                            _mname = str(model).split('/')[-1]
                        except Exception:
                            _mname = str(model)
                        ax.set_ylabel(f"{_mname}\nLatency (us)")
                        ax.tick_params(axis='y', which='both', labelleft=True)
                    else:
                        ax.tick_params(axis='y', which='both', labelleft=False)
                    # Show x tick labels on all rows; keep x-axis label on bottom row
                    if r == nrows - 1:
                        ax.set_xlabel("Tokens")
                    ax.tick_params(axis='x', which='both', labelbottom=True)
                    # Apply aligned y-limits and consistent locators per row
                    ymin, ymax = row_y_lims.get(r, (None, None))
                    if ymin is not None and ymax is not None:
                        ax.set_ylim(ymin, ymax)
                        if log_y:
                            ax.set_yscale('log')
                            ax.yaxis.set_major_locator(LogLocator(base=10.0))
                            ax.yaxis.set_major_formatter(LogFormatterMathtext())
                            ax.minorticks_on()
                        else:
                            ax.yaxis.set_major_locator(MaxNLocator(nbins=4))
                        # Re-apply grid so it aligns with the final tick locators
                        ax.grid(True, which='both', axis='both', linestyle='--', alpha=0.25, linewidth=0.4)

            handles = []
            labels = []
            for r in range(nrows):
                for c in range(ncols):
                    h, l = axs[r][c].get_legend_handles_labels()
                    if h and l:
                        handles, labels = h, l
                        break
                if handles and labels:
                    break

            # Ensure legend items are in the correct order: Python, Rust, PinTok
            ordered_handles = []
            ordered_labels = []
            for desired_label in ["Python", "Rust", "PinTok"]:
                if desired_label in labels:
                    idx = labels.index(desired_label)
                    ordered_handles.append(handles[idx])
                    ordered_labels.append(labels[idx])

            fig.tight_layout(rect=[0.006, 0.01, 0.88, 1.0])
            if ordered_handles and ordered_labels:
                fig.legend(
                    ordered_handles,
                    ordered_labels,
                    loc='center left',
                    ncol=1,
                    frameon=False,
                    bbox_to_anchor=(0.89, 0.5),
                    borderaxespad=0.0,
                    handlelength=1.5,
                    handletextpad=0.5,
                    markerscale=1.0,
                    borderpad=0.2,
                )
            os.makedirs(out_dir, exist_ok=True)
            out_path = os.path.join(out_dir, outfile)
            fig.savefig(out_path, bbox_inches='tight', pad_inches=0.02)
            plt.close(fig)
            print(f"Saved by-model latency grid: {out_path}")
            return True
        except Exception as e:
            print(f"Warning: failed to render by-model latency grid for {dataset}: {e}")
            return False

    def _plot_throughput_grid_by_model(model_rows_map: dict[str, list[dict]], dataset: str, out_dir: str, outfile: str):
        try:
            import matplotlib.pyplot as plt
            import traceback
            # Build throughput series from throughput_summary.csv-like rows
            def _build_series(rows: list[dict]) -> dict[str, list[tuple[int, float, float]]]:
                series = {"pintok": [], "rust": [], "python": []}
                has_tps = any('pintok_tps_mean' in r for r in rows) or any('pintok_tps_mean' in (k or '') for r in rows for k in r.keys())
                if has_tps:
                    col = {
                        "pintok": ("pintok_tps_mean", "pintok_tps_std"),
                        "rust": ("rust_tps_mean", "rust_tps_std"),
                        "python": ("python_tps_mean", "python_tps_std"),
                    }
                    for r in rows:
                        if r.get('dataset') != dataset:
                            continue
                        try:
                            tokens = int(float(r.get('tokens_per_packet', '')))
                        except Exception:
                            continue
                        for mode, (mcol, scol) in col.items():
                            try:
                                mean = float(r.get(mcol)) if r.get(mcol) else None
                                std = float(r.get(scol)) if r.get(scol) else None
                            except Exception:
                                mean, std = None, None
                            if mean is not None:
                                series[mode].append((tokens, mean, std or 0.0))
                else:
                    # Convert from latency summary rows
                    def tf(mu_us: str | None, sd_us: str | None, tokens: int) -> tuple[float | None, float | None]:
                        try:
                            mu = float(mu_us) if mu_us else None
                            sd = float(sd_us) if sd_us else None
                        except Exception:
                            return (None, None)
                        if mu and mu > 0:
                            mean_tps = tokens * 1_000_000.0 / mu
                            std_tps = (tokens * 1_000_000.0 * (sd or 0.0) / (mu * mu))
                            return (mean_tps, std_tps)
                        return (None, None)
                    for r in rows:
                        if r.get('dataset') != dataset:
                            continue
                        try:
                            tokens = int(float(r.get('tokens_per_packet', '')))
                        except Exception:
                            continue
                        dpdk_m, dpdk_s = tf(r.get('avg_pintok_mean_us'), r.get('avg_pintok_std_us'), tokens)
                        rust_m, rust_s = tf(r.get('avg_rust_mean_us'), r.get('avg_rust_std_us'), tokens)
                        py_m, py_s = tf(r.get('avg_python_mean_us'), r.get('avg_python_std_us'), tokens)
                        if dpdk_m is not None:
                            series['pintok'].append((tokens, dpdk_m, dpdk_s or 0.0))
                        if rust_m is not None:
                            series['rust'].append((tokens, rust_m, rust_s or 0.0))
                        if py_m is not None:
                            series['python'].append((tokens, py_m, py_s or 0.0))
                for mode in list(series.keys()):
                    series[mode].sort(key=lambda t: t[0])
                return series

            COLORS = {"python": "#FDB515", "rust": "#002676", "pintok": "#8C1515"}
            LABELS = {"python": "Python", "rust": "Rust", "pintok": "PinTok"}

            model_names = [m for m, rows in model_rows_map.items() if any(r.get('dataset') == dataset for r in rows)]
            if not model_names:
                return False

            # Increase font sizes by 2pt for better readability
            plt.rcParams.update({
                'font.size': 8.5,
                'axes.titlesize': 8.5,
                'axes.labelsize': 8.0,
                'xtick.labelsize': 6.0,
                'ytick.labelsize': 7.0,
                'legend.fontsize': 7.5,
                'axes.linewidth': 0.6,
                'xtick.major.width': 0.5,
                'ytick.major.width': 0.5,
                'xtick.major.size': 2.0,
                'ytick.major.size': 2.0,
                'grid.linewidth': 0.4,
            })

            if args.debug:
                print(f"[DEBUG] Throughput grid for dataset={dataset}, models={model_names}")
            ncols = len(model_names)
            import math
            # Reduce height by 30% (from 2.4 to 1.68) and increase width for legend
            width = 6.2 if ncols >= 3 else max(2.8, 2.2 * ncols)
            fig, axs = plt.subplots(1, ncols, figsize=(width, 1.68))
            if ncols == 1:
                axs = [axs]

            for c, model in enumerate(model_names):
                ax = axs[c]
                series = _build_series(model_rows_map.get(model, []))
                all_tokens = sorted({t for mode in series.values() for (t, _, _) in mode})
                if args.debug:
                    print(f"[DEBUG] Model {model}: token lengths={all_tokens}")
                n_groups = len(all_tokens)
                token_index = {t: i for i, t in enumerate(all_tokens)}
                x_centers = list(range(n_groups))
                ax.set_xticks(x_centers)
                ax.set_xticklabels([f"{t:,}" for t in all_tokens])
                bar_width = 0.26 if n_groups > 1 else 0.28
                offsets = {'python': -bar_width, 'rust': 0.0, 'pintok': bar_width}
                for mode in ["python", "rust", "pintok"]:
                    pts = series.get(mode, [])
                    if not pts:
                        continue
                    xs = [token_index[t] + offsets[mode] for (t, _, _) in pts]
                    ys = [m for (_, m, _) in pts]
                    es = [s for (_, _, s) in pts]
                    # Fix for numpy array comparison
                    has_errors = False
                    try:
                        has_errors = any(float(e) > 0 for e in es)
                    except (TypeError, ValueError) as ex:
                        if args.debug:
                            print(f"[DEBUG] Error checking errors: {ex}, es={es}")
                        has_errors = False
                    except Exception as ex:
                        if args.debug:
                            print(f"[DEBUG] Unexpected error in error check: {ex}, es={es}, type={[type(e) for e in es]}")
                        import traceback
                        traceback.print_exc()
                        has_errors = False

                    ax.bar(
                        xs,
                        ys,
                        width=bar_width,
                        color=COLORS[mode],
                        label=LABELS[mode],
                        linewidth=0.0,
                        yerr=es if has_errors else None,
                        error_kw=dict(ecolor='black', elinewidth=0.7, capsize=2.0, capthick=0.7),
                    )
                ax.grid(True, linestyle='--', alpha=0.25, linewidth=0.4, axis='y')
                ax.set_xlim(-0.6, max(x_centers) + 0.6 if x_centers else 0.4)
                ax.set_title(model)
                if c == 0:
                    ax.set_ylabel("Throughput\n(tokens/s)")
                # Show y tick labels for all subplots
                ax.tick_params(axis='y', which='both', labelleft=True)
                ax.set_xlabel("Tokens")
                # Per-model y-limits based on data points for this model
                import math as _math
                y_min = _math.inf
                y_max = -_math.inf
                for mode in ["python", "rust", "pintok"]:
                    for (_, mean, std) in series.get(mode, []):
                        e = float(std or 0.0)
                        y_min = min(y_min, float(mean) - e)
                        y_max = max(y_max, float(mean) + e)
                if y_min is _math.inf or y_max is -_math.inf:
                    y_min, y_max = (0.0, 1.0)
                span = max(1e-9, y_max - y_min)
                pad = 0.05 * span
                ax.set_ylim(max(0.0, y_min - pad), y_max + pad)
                from matplotlib.ticker import MaxNLocator as _MaxNLocator
                ax.yaxis.set_major_locator(_MaxNLocator(nbins=4))
                ax.grid(True, which='both', axis='y', linestyle='--', alpha=0.25, linewidth=0.4)

            handles, labels = axs[0].get_legend_handles_labels() if (axs is not None and len(axs) > 0) else ([], [])
            # Reorder to Python, Rust, PinTok
            desired = ["Python", "Rust", "PinTok"]
            if len(handles) > 0 and len(labels) > 0:
                map_h = {l: h for h, l in zip(handles, labels)}
                labels = [l for l in desired if l in map_h]
                handles = [map_h[l] for l in labels]
            # Adjust layout for right-side legend
            fig.tight_layout(rect=[0.01, 0.01, 0.88, 0.99])
            if len(handles) > 0 and len(labels) > 0:
                # Place legend on right side, vertically centered
                fig.legend(handles, labels, loc='center left', ncol=1, frameon=False, bbox_to_anchor=(0.90, 0.5), borderaxespad=0.0, handlelength=1.5, handletextpad=0.5, markerscale=1.0, borderpad=0.2)
            os.makedirs(out_dir, exist_ok=True)
            out_path = os.path.join(out_dir, outfile)
            fig.savefig(out_path, bbox_inches='tight', pad_inches=0.02)
            print(f"Saved by-model throughput grid: {out_path}")
            return True
        except Exception as e:
            print(f"Warning: failed to render by-model throughput grid for {dataset}: {e}")
            import traceback
            traceback.print_exc()
            return False

    # Plot-only shortcut: render per-metric plots (linear & log-y) and 3x4 grids, then exit
    if args.plot_only:
        try:
            import subprocess
            import csv as _csv
            plot_script = os.path.join(THIS_DIR, 'plot_results.py')
            # Collected CSVs for by-model grids
            _plot_only_model_csvs: dict[str, str] = {}
            _plot_only_model_throughput_csvs: dict[str, str] = {}

            # For plot-only with --embed-model all, discover which models have results
            if args.embed_model and args.embed_model.lower() == 'all':
                import glob
                discovered_models = []
                for csv_file in glob.glob(os.path.join(test_dir, 'results_summary_*.csv')):
                    basename = os.path.basename(csv_file)
                    if basename.startswith('results_summary_') and basename.endswith('.csv'):
                        # Extract model name from filename
                        model_suffix = basename[len('results_summary_'):-4]
                        if model_suffix:  # Skip the standalone case
                            # Convert back to model name format
                            model_name = model_suffix.replace('_', '/')
                            discovered_models.append(model_name)
                if discovered_models:
                    embed_models = discovered_models
                    print(f"Plot-only: discovered models with results: {embed_models}")

            for em in embed_models:
                suffix = '' if (not em or em.lower() in ('standalone','none')) else ('_' + _sanitize_suffix(em))
                # In plot-only mode, never use args.out_file for embed models
                csv_path = os.path.join(test_dir, f'results_summary{suffix}.csv')
                # Metric plots (linear and log-y)
                for m in ['avg', 'p50', 'p90', 'p99']:
                    base_cmd = [sys.executable, plot_script, '--csv', csv_path, '--metric', m, '--out-dir', args.plot_out_dir]
                    if args.plot_dataset:
                        base_cmd += ['--dataset'] + args.plot_dataset
                    else:
                        base_cmd += ['--all-datasets']
                    print(f"Plot-only: running {' '.join(base_cmd)}")
                    subprocess.run(base_cmd, check=True)
                    log_cmd = base_cmd + ['--log-y']
                    print(f"Plot-only: running {' '.join(log_cmd)}")
                    subprocess.run(log_cmd, check=True)

                # Grid plots: dataset-rows 3x4 for standalone; for embed-model collect for by-model grids
                if not em or em.lower() in ('standalone','none'):
                    grid_script = os.path.join(THIS_DIR, 'plot_grid.py')
                    grid_cmd = [sys.executable, grid_script, '--csv', csv_path, '--out-dir', args.plot_out_dir, '--outfile', 'plot_grid_3x4.pdf']
                    print(f"Plot-only: running {' '.join(grid_cmd)}")
                    subprocess.run(grid_cmd, check=True)
                    grid_cmd_log = [sys.executable, grid_script, '--csv', csv_path, '--out-dir', args.plot_out_dir, '--outfile', 'plot_grid_3x4_logy.pdf', '--log-y']
                    print(f"Plot-only: running {' '.join(grid_cmd_log)}")
                    subprocess.run(grid_cmd_log, check=True)
                else:
                    _plot_only_model_csvs[em] = csv_path

                # Throughput CSV and plots for this suffix
                base = os.path.basename(csv_path)
                tsuffix = ''
                if base.startswith('results_summary') and base.endswith('.csv'):
                    tsuffix = base[len('results_summary'):-4]
                th_csv = os.path.join(args.plot_out_dir, f'throughput_summary{tsuffix}.csv')
                rows = []
                with open(csv_path, 'r') as f:
                    reader = _csv.DictReader(f)
                    for r in reader:
                        try:
                            dataset = r.get('dataset')
                            tokens = int(float(r.get('tokens_per_packet', '0')))
                            trials = int(float(r.get('trials', '0'))) if r.get('trials') else None
                            def tf(mean_us_str, std_us_str):
                                """Turn a latency mean/std in microseconds into throughput in tokens/s.

                                Uses ``tokens`` from the enclosing row. Returns ``(mean_tps, std_tps)``,
                                or ``(None, None)`` when the mean is blank, unparsable, or not > 0.
                                The std is propagated to first order (``tokens*1e6*sd/mu**2``); a blank
                                std yields 0.0.
                                """
                                try:
                                    mean_us = float(mean_us_str) if mean_us_str else None
                                    std_us = float(std_us_str) if std_us_str else None
                                except Exception:
                                    return (None, None)
                                # `not mean_us > 0` (rather than `mean_us <= 0`) also rejects NaN.
                                if not mean_us or not mean_us > 0:
                                    return (None, None)
                                scale = tokens * 1_000_000.0
                                std_tps = 0.0 if std_us is None else scale * std_us / (mean_us * mean_us)
                                return (scale / mean_us, std_tps)
                            dpdk_mean, dpdk_std = tf(r.get('avg_pintok_mean_us'), r.get('avg_pintok_std_us'))
                            rust_mean, rust_std = tf(r.get('avg_rust_mean_us'), r.get('avg_rust_std_us'))
                            py_mean, py_std = tf(r.get('avg_python_mean_us'), r.get('avg_python_std_us'))
                            def ratio(a, b):
                                """Return ``a / b``; None when ``a`` is None or ``b`` is missing/zero."""
                                if a is None or not b:
                                    return None
                                return a / b
                            rows.append({
                                'dataset': dataset,
                                'trials': trials,
                                'tokens_per_packet': tokens,
                                'pintok_tps_mean': dpdk_mean,
                                'pintok_tps_std': dpdk_std,
                                'rust_tps_mean': rust_mean,
                                'rust_tps_std': rust_std,
                                'python_tps_mean': py_mean,
                                'python_tps_std': py_std,
                                'improve_vs_rust': ratio(dpdk_mean, rust_mean),
                                'improve_vs_python': ratio(dpdk_mean, py_mean),
                            })
                        except Exception:
                            continue
                with open(th_csv, 'w') as f:
                    headers = [
                        'dataset', 'trials', 'tokens_per_packet',
                        'pintok_tps_mean', 'pintok_tps_std',
                        'rust_tps_mean', 'rust_tps_std', 'improve_vs_rust',
                        'python_tps_mean', 'python_tps_std', 'improve_vs_python',
                    ]
                    f.write(','.join(headers) + '\n')
                    for r in rows:
                        def fmt(x):
                            """Render one throughput CSV cell: '' for None, 2 decimals for floats."""
                            if isinstance(x, float):
                                return format(x, '.2f')
                            return str(x) if x is not None else ''
                        f.write(','.join(fmt(r.get(h)) for h in headers) + '\n')
                print(f"Plot-only: wrote throughput summary CSV to {th_csv}")

                # Throughput plots: default dataset-columns grid only for standalone
                if not em or em.lower() in ('standalone','none'):
                    plot_th = os.path.join(THIS_DIR, 'plot_throughput_results.py')
                    for ds in (args.plot_dataset if args.plot_dataset else ['openwebtext', 'code', 'multilingual']):
                        cmd = [sys.executable, plot_th, '--csv', th_csv, '--dataset', ds, '--out-dir', args.plot_out_dir]
                        print(f"Plot-only: running {' '.join(cmd)}")
                        subprocess.run(cmd, check=True)
                    th_grid = os.path.join(THIS_DIR, 'plot_throughput_grid.py')
                    grid_cmd = [sys.executable, th_grid, '--csv', th_csv, '--out-dir', args.plot_out_dir]
                    print(f"Plot-only: running {' '.join(grid_cmd)}")
                    subprocess.run(grid_cmd, check=True)
                else:
                    _plot_only_model_throughput_csvs[em] = th_csv
            # After looping models, render by-model grids per dataset if any
            if _plot_only_model_csvs:
                import csv as _csv2
                discovered_ds = set()
                for _em, cpath in _plot_only_model_csvs.items():
                    try:
                        with open(cpath, 'r') as f:
                            for row in _csv2.DictReader(f):
                                ds = row.get('dataset')
                                if ds:
                                    discovered_ds.add(ds)
                    except Exception:
                        pass
                datasets_to_plot = args.plot_dataset if args.plot_dataset else sorted(discovered_ds)
                for ds in datasets_to_plot:
                    m2rows = {}
                    for _em, cpath in _plot_only_model_csvs.items():
                        try:
                            with open(cpath, 'r') as f:
                                m2rows[_em] = list(_csv2.DictReader(f))
                        except Exception:
                            pass
                    out_dir = os.path.join(args.plot_out_dir, 'by_model')
                    ok = _plot_latency_grid_by_model(m2rows, ds, out_dir, f'plot_grid_by_model_{ds}.pdf', log_y=False)
                    if ok:
                        _ = _plot_latency_grid_by_model(m2rows, ds, out_dir, f'plot_grid_by_model_{ds}_logy.pdf', log_y=True)
                    if _plot_only_model_throughput_csvs:
                        m2trows = {}
                        for _em, tpath in _plot_only_model_throughput_csvs.items():
                            try:
                                with open(tpath, 'r') as f:
                                    m2trows[_em] = list(_csv2.DictReader(f))
                            except Exception:
                                pass
                        if m2trows:
                            _plot_throughput_grid_by_model(m2trows, ds, out_dir, f'plot_throughput_by_model_{ds}.pdf')
            return
        except Exception as e:
            print(f"Plot-only failed: {e}")
            return

    # Validate required inputs when not plotting only
    # When an embedding model is used, token length can be set via --tokens-per-packet;
    # otherwise a single length equal to model's max (capped by receiver) is used.
    if not args.mode or not args.dataset:
        print("Error: --mode and --dataset are required unless --plot-only is used")
        return
    if (not args.embed_model) or (args.embed_model.lower() in ('standalone','none')):
        # Allow per-model overrides or config defaults to supply tokens for standalone mode
        has_override_for_standalone = any(k in tpp_overrides_map for k in ('standalone', 'none', '*'))
        has_cfg_default = bool(tpp_default_from_cfg)
        if not args.tokens_per_packet and not has_override_for_standalone and not has_cfg_default:
            print("Error: provide --tokens-per-packet, --tpp-override for 'standalone', or --tpp-config with defaults for standalone tokenizer tests")
            return
    else:
        # For embed-model runs, use --tokens-per-packet values if provided (filtering unsupported lengths)
        if args.tokens_per_packet and args.debug:
            print("Info: --embed-model set; using --tokens-per-packet values for embed runs (filtering unsupported lengths).")

    # Guard when writing into existing test directory unless --override
    if os.path.isdir(test_dir) and not args.override:
        print(f"Refusing to overwrite existing test folder: {test_dir}")
        print("Pass --override to allow writing new results into this folder.")
        return
    # Ensure results directories exist
    os.makedirs(test_dir, exist_ok=True)

    # Expand modes/datasets once
    modes = args.mode
    if any(m.lower() == 'all' for m in modes):
        modes = ['dpdk', 'rust', 'python']
    datasets = args.dataset
    if any(d.lower() == 'all' for d in datasets):
        expanded = []
        for d in datasets:
            if d.lower() == 'all':
                expanded += ['openwebtext', 'multilingual', 'code']
            else:
                expanded.append(d)
        # Deduplicate while preserving order
        seen = set()
        datasets = [x for x in expanded if not (x in seen or seen.add(x))]
    token_counts = args.tokens_per_packet

    # Helper to resolve max token length for an embedding model
    def _resolve_model_max_len(model_id: str) -> int:
        """Return the maximum sequence length declared by an embedding model's config.

        Falls back to 512 for standalone runs (empty id, or 'standalone'/'none' in any
        case), when ``transformers`` is unavailable (a warning is printed), when the
        config cannot be loaded (a warning is printed), or when none of the known
        length attributes holds a positive int.
        """
        fallback = 512
        if (not model_id) or model_id.lower() in ('standalone', 'none'):
            return fallback
        if not _HAS_TRANSFORMERS:
            print("Warning: transformers not available; using default model length 512")
            return fallback
        try:
            config = AutoConfig.from_pretrained(model_id, trust_remote_code=False)
            # First attribute (in this priority order) holding a positive int wins.
            # NOTE(review): some architectures may count special/padding positions in
            # max_position_embeddings; confirm the usable length per model.
            declared = next(
                (
                    int(value)
                    for value in (
                        getattr(config, attr, None)
                        for attr in ('max_position_embeddings', 'max_seq_length', 'max_sequence_length')
                    )
                    if isinstance(value, int) and value > 0
                ),
                None,
            )
            if declared is not None:
                return declared
        except Exception as e:
            print(f"Warning: failed to load config for {model_id}: {e}. Using {fallback}.")
        return fallback

    # Prepare result rows per embed model
    _model_csvs_all: dict[str, str] = {}
    _model_throughput_csvs_all: dict[str, str] = {}
    for _i, _em in enumerate(embed_models):
        if args.debug:
            print(f"[DEBUG] Starting model {_i+1}/{len(embed_models)}: {_em}")
        out_rows = []
        cache_rows = []
        throughput_rows = []
        # Prepare per-model output path early so we can append rows incrementally
        suffix = '' if (not _em or _em.lower() in ('standalone','none')) else ('_' + _sanitize_suffix(_em))
        out_file = os.path.join(test_dir, f'results_summary{suffix}.csv')
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        # If overriding, start with a clean file so we don't mix previous runs
        try:
            if args.override and os.path.exists(out_file):
                os.remove(out_file)
        except Exception:
            pass
        # CSV header for latency summary
        _latency_headers = [
            'dataset', 'trials', 'tokens_per_packet',
            # Average latency group
            'avg_pintok_mean_us', 'avg_pintok_std_us',
            'avg_rust_mean_us', 'avg_rust_std_us', 'avg_improve_vs_rust',
            'avg_python_mean_us', 'avg_python_std_us', 'avg_improve_vs_python',
            'avg_embed_mean_us', 'avg_embed_std_us',
            # Encoder latency (mean of per-run encode_mean)
            'avg_encode_pintok_mean_us', 'avg_encode_pintok_std_us',
            'avg_encode_rust_mean_us', 'avg_encode_rust_std_us',
            'avg_encode_python_mean_us', 'avg_encode_python_std_us',
            # P50 group
            'p50_pintok_mean_us', 'p50_pintok_std_us',
            'p50_rust_mean_us', 'p50_rust_std_us', 'p50_improve_vs_rust',
            'p50_python_mean_us', 'p50_python_std_us', 'p50_improve_vs_python',
            'p50_embed_mean_us', 'p50_embed_std_us',
            # P90 group
            'p90_pintok_mean_us', 'p90_pintok_std_us',
            'p90_rust_mean_us', 'p90_rust_std_us', 'p90_improve_vs_rust',
            'p90_python_mean_us', 'p90_python_std_us', 'p90_improve_vs_python',
            'p90_embed_mean_us', 'p90_embed_std_us',
            # P99 group
            'p99_pintok_mean_us', 'p99_pintok_std_us',
            'p99_rust_mean_us', 'p99_rust_std_us', 'p99_improve_vs_rust',
            'p99_python_mean_us', 'p99_python_std_us', 'p99_improve_vs_python',
            'p99_embed_mean_us', 'p99_embed_std_us',
        ]
        def _fmt_val(x):
            """Render one latency CSV cell: '' for None, two decimals for floats, str() otherwise."""
            if isinstance(x, float):
                return f"{x:.2f}"
            return '' if x is None else str(x)

        def _csv_row_values(row: dict) -> list[str]:
            """Render ``row`` as formatted cells in ``_latency_headers`` order (missing keys -> '')."""
            cells: list[str] = []
            for header in _latency_headers:
                cells.append(_fmt_val(row.get(header)))
            return cells

        def _print_csv_row(row: dict, prefix: str = "[csv]") -> None:
            """Under --debug, echo the CSV line that ``row`` renders to, prefixed by ``prefix``.

            This is a diagnostic aid only: any error while rendering or printing is
            swallowed so it can never interrupt a run.
            """
            try:
                rendered = ','.join(_csv_row_values(row))
                if not args.debug:
                    return
                print(f"{prefix} {rendered}")
            except Exception:
                pass
        def _append_latency_row(row: dict):
            """Append one row to the latency summary CSV (``out_file``).

            The header (``_latency_headers``) is written first when the file is missing
            or empty. Cells are rendered with ``_fmt_val`` in header order, so keys absent
            from ``row`` become empty cells (this is how early placeholder rows are written).

            I/O problems never abort the run (plotting may proceed from in-memory rows),
            but they are always reported with a warning rather than being dropped
            silently when --debug is off.
            """
            row_tokens = None
            try:
                row_tokens = row.get('tokens_per_packet')
                need_header = (not os.path.exists(out_file)) or (os.path.getsize(out_file) == 0)
                if args.debug:
                    print(f"[debug] append start tokens={row_tokens} header_needed={need_header}")
                # Render the row before opening the file so a formatting error cannot
                # leave an orphan header behind.
                line = ','.join(_fmt_val(row.get(h)) for h in _latency_headers) + '\n'
                with open(out_file, 'a') as f:
                    if need_header:
                        f.write(','.join(_latency_headers) + '\n')
                    f.write(line)
                if args.debug:
                    print(f"[debug] wrote row tokens={row_tokens} header={need_header} path={out_file}")
            except Exception as _ioe:
                # Best-effort: keep the run going, but make lost rows visible.
                print(f"Warning: failed to append latency row tokens={row_tokens} to {out_file}: {_ioe}")

        def _finalize_latency_csv(rows: list[dict]):
            """Merge rows already on disk with ``rows`` and rewrite ``out_file`` once.

            Rows are keyed by ``(dataset, tokens_per_packet)``. On-disk rows are loaded
            first and in-memory ``rows`` are applied afterwards, so for the same key the
            in-memory row wins (this also replaces early placeholder rows). The output is
            ordered by dataset, then token count ascending. This guards against any
            accidental overwrites elsewhere and ensures all token lengths appear in the
            final file.

            The merged content is written to ``out_file + '.tmp'`` and moved into place
            with ``os.replace``. A failure part-way through therefore cannot truncate the
            rows already appended incrementally. Failures are reported with a warning
            and never raised.
            """
            import csv as _csv

            def _tokens_of(r: dict) -> int:
                # Missing/blank/non-numeric token counts map to 0 instead of aborting the merge.
                try:
                    return int(float(r.get('tokens_per_packet') or 0))
                except (TypeError, ValueError, OverflowError):
                    return 0

            def _key(r: dict) -> tuple:
                return (r.get('dataset'), _tokens_of(r))

            def _cell(r: dict, header: str) -> str:
                # csv.DictReader fills short rows with None; emit '' rather than the text 'None'.
                value = r.get(header)
                return '' if value is None else str(value)

            tmp_path = out_file + '.tmp'
            try:
                existing: list[dict] = []
                if os.path.exists(out_file) and os.path.getsize(out_file) > 0:
                    with open(out_file, 'r', newline='') as rf:
                        existing.extend(_csv.DictReader(rf))
                merged: dict[tuple, dict] = {}
                for r in existing:
                    merged[_key(r)] = r
                for r in rows:
                    merged[_key(r)] = {k: _fmt_val(v) for k, v in r.items()}
                # The sort key must never raise: None datasets sort as '' and bad token counts as 0.
                ordered = sorted(merged.values(), key=lambda r: (str(r.get('dataset') or ''), _tokens_of(r)))
                if args.debug:
                    keys = [(r.get('dataset'), _tokens_of(r)) for r in ordered]
                    print(f"[debug] finalize: existing_rows={len(existing)} in_mem_rows={len(rows)} merged_rows={len(ordered)} keys={keys}")
                with open(tmp_path, 'w') as wf:
                    wf.write(','.join(_latency_headers) + '\n')
                    for r in ordered:
                        wf.write(','.join(_cell(r, h) for h in _latency_headers) + '\n')
                # Swap the file in a single step so readers never see a truncated CSV.
                os.replace(tmp_path, out_file)
            except Exception as e:
                # Best-effort: keep the run going, but make the failure visible.
                print(f"Warning: failed to finalize latency CSV {out_file}: {e}")
                try:
                    if os.path.exists(tmp_path):
                        os.remove(tmp_path)
                except OSError:
                    pass
        # Determine token counts per packet for this model (with per-model overrides)
        model_id = (_em or 'standalone')
        is_standalone = (not _em) or (_em.lower() in ('standalone','none'))
        model_max = None if is_standalone else _resolve_model_max_len(model_id)
        # Precedence: --tpp-override model -> override '*' -> --tpp-config default -> --tokens-per-packet -> fallback
        requested_list = None
        if model_id in tpp_overrides_map:
            requested_list = tpp_overrides_map.get(model_id)
        elif (not is_standalone) and ('*' in tpp_overrides_map):
            requested_list = tpp_overrides_map.get('*')
        elif is_standalone and any(k in tpp_overrides_map for k in ('standalone','none','*')):
            requested_list = tpp_overrides_map.get('standalone') or tpp_overrides_map.get('none') or tpp_overrides_map.get('*')
        elif tpp_default_from_cfg:
            requested_list = tpp_default_from_cfg
        else:
            requested_list = args.tokens_per_packet

        if requested_list:
            # Preserve provided order; dedupe; enforce receiver cap; warn when exceeding model's max
            token_counts = []
            seen_tc = set()
            for item in requested_list:
                try:
                    v = int(item)
                except Exception:
                    continue
                if v <= 0 or v in seen_tc:
                    continue
                if v > int(args.max_seq_len):
                    print(f"Skipping tokens-per-packet={v}: exceeds receiver max ({args.max_seq_len})")
                    continue
                if (not is_standalone) and model_max and v > model_max:
                    print(f"Warning: tokens-per-packet={v} exceeds {model_id} declared max ({model_max}); will attempt and record DP results anyway.")
                token_counts.append(v)
                seen_tc.add(v)
            if not token_counts:
                print(f"No valid tokens-per-packet values for {model_id}; skipping this model's runs.")
        else:
            # Default: single test at the effective max length (embed models only)
            if is_standalone:
                # Should not reach here due to earlier guard, but fallback to 512
                token_counts = [min(int(args.max_seq_len), 512)]
            else:
                eff_len = min(model_max or int(args.max_seq_len), int(args.max_seq_len))
                token_counts = [eff_len]
                print(f"Embed model {model_id}: using tokens-per-packet={eff_len} (model_max={model_max}, receiver_cap={args.max_seq_len})")

        # Log the resolved token list once per embed model
        if args.debug:
            try:
                print(f"Resolved tokens-per-packet for {_em or 'standalone'}: {token_counts}")
                print(f"[DEBUG] About to process datasets for model {_em}")
            except Exception:
                pass

        for dataset in datasets:
            for tokens_pp in token_counts:
                if args.debug:
                    print(f"\n=== Combination: dataset={dataset}, tokens/packet={tokens_pp} ===")
                # Write a placeholder row early so the length is visible even if later aggregation fails
                try:
                    _append_latency_row({'dataset': dataset, 'trials': args.runs, 'tokens_per_packet': tokens_pp})
                    if args.debug:
                        print(f"[debug] appended placeholder row tokens={tokens_pp}")
                except Exception as _ph_ex:
                    try:
                        if args.debug:
                            print(f"[debug] placeholder append failed tokens={tokens_pp} err={_ph_ex}")
                    except Exception:
                        pass
                # Accumulate per-mode stats across trials with distinct dataset-offsets
                per_mode_runs: dict[str, list[dict]] = {m: [] for m in modes}
                per_mode_embeddings: dict[str, list] = {m: [] for m in modes}  # Store embeddings per mode
                # Optional persistent Rust pipeline to keep cache across trials
                persistent_rust = None
                for trial_idx in range(args.runs):
                    # Use the same set of packets across trials for stability
                    dataset_offset = 0
                    if args.debug:
                        print(f"  Trial {trial_idx+1}/{args.runs}")
                    for mode in modes:
                        print(f"    - {mode} (tokens={tokens_pp})")
                        if mode == 'rust' and args.disable_rust_cache_clearing_per_trial:
                            # Build or reuse a single Rust pipeline across trials (unfair, by request)
                            if persistent_rust is None:
                                persistent_rust = PythonRustPipeline(
                                    model_name=args.model,
                                    encoder_type=args.encoder,
                                    encoder_model=_em,
                                    force_cpu=args.cpu,
                                    debug=args.debug,
                                    batch_size=args.batch_size,
                                    enable_batch=effective_enable_batch,
                                    use_fast=True,
                                    latency_mode=args.latency_mode,
                                    warmup=effective_warmup,
                                    disable_cache=args.disable_cache,
                                    pin_core=None,
                                )
                            # Bind socket for this trial (tokenizer stays loaded across start/stop)
                            persistent_rust.start()
                            stats = run_single_experiment(
                                mode=mode,
                                tokenizer=args.tokenizer,
                                model=args.model,
                                max_packets=args.max_packets,
                                enable_batch=effective_enable_batch,
                                batch_size=args.batch_size,
                                encoder=args.encoder,
                                embed_model=_em,
                                force_cpu=args.cpu,
                                debug=args.debug,
                                debug_bert=args.debug_bert,
                                dataset=dataset,
                                timeout_seconds=args.timeout,
                                tokens_per_packet=tokens_pp,
                                max_seq_len=args.max_seq_len,
                                delay_ms=args.delay_ms,
                                latency_mode=args.latency_mode,
                                pin_core=effective_pin_core,
                                rt_prio=effective_rt_prio,
                                dpdk_log_level=effective_dpdk_log,
                                use_sudo=(not args.disable_sudo),
                                use_msg_id_header=(not args.no_msg_id_header),
                                allow_non_isolated=args.allow_non_isolated,
                                dataset_offset=dataset_offset,
                                disable_cache=args.disable_cache,
                                warmup=effective_warmup,
                                external_pipeline=persistent_rust,
                                keep_pipeline_alive=True,
                                clear_rust_cache_before_trial=False,
                                embed_packet_delay_ms=args.embed_packet_delay_ms,
                                embed_settle_ms=getattr(args, 'embed_settle_ms', 30000),
                                embed_missing_tolerance=getattr(args, 'embed_missing_tolerance', 2),
                                dup_chunks_override=args.dup_chunks,
                            )
                            # Unbind to free port for other modes in this trial
                            try:
                                persistent_rust.stop()
                            except Exception:
                                pass
                        else:
                            stats = run_single_experiment(
                                mode=mode,
                                tokenizer=args.tokenizer,
                                model=args.model,
                                max_packets=args.max_packets,
                                enable_batch=effective_enable_batch,
                                batch_size=args.batch_size,
                                encoder=args.encoder,
                                embed_model=_em,
                                force_cpu=args.cpu,
                                debug=args.debug,
                                debug_bert=args.debug_bert,
                                dataset=dataset,
                                timeout_seconds=args.timeout,
                                tokens_per_packet=tokens_pp,
                                max_seq_len=args.max_seq_len,
                                delay_ms=args.delay_ms,
                                latency_mode=args.latency_mode,
                                pin_core=effective_pin_core,
                                rt_prio=effective_rt_prio,
                                dpdk_log_level=effective_dpdk_log,
                                use_sudo=(not args.disable_sudo),
                                use_msg_id_header=(not args.no_msg_id_header),
                                allow_non_isolated=args.allow_non_isolated,
                                dataset_offset=dataset_offset,
                                disable_cache=args.disable_cache,
                                warmup=effective_warmup,
                                embed_packet_delay_ms=args.embed_packet_delay_ms,
                                embed_settle_ms=getattr(args, 'embed_settle_ms', 30000),
                                embed_missing_tolerance=getattr(args, 'embed_missing_tolerance', 2),
                                dup_chunks_override=args.dup_chunks,
                            )
                        count = stats.get('count', 0)
                        if args.debug:
                            print(f"      Processed packets: {count} (mode={mode}, tokens={tokens_pp})")
                        try:
                            print(f"      Stats keys={list(stats.keys())} mean={stats.get('mean')} median={stats.get('median')} p90={stats.get('p90')} p99={stats.get('p99')}")
                        except Exception:
                            pass
                        per_mode_runs[mode].append(stats)
                        # Store embeddings for comparison (only store from first trial to avoid duplication)
                        if args.compare_embeddings and trial_idx == 0 and 'embeddings' in stats:
                            per_mode_embeddings[mode].extend(stats['embeddings'])
                # Stop persistent Rust pipeline at end of combination
                if persistent_rust is not None:
                    try:
                        persistent_rust.stop()
                    except Exception:
                        pass
                    persistent_rust = None

                # Aggregate across trials per mode (per tokens_per_packet combination)
                combo_results = {}
                for mode in modes:
                    runs_list = per_mode_runs.get(mode, [])
                    if args.debug:
                        print(f"[DEBUG] Mode {mode} has {len(runs_list)} runs with counts: {[r.get('count', 0) for r in runs_list]}")
                    agg = aggregate_runs(runs_list)
                    mode_label = mode if mode != 'dpdk' else (f"PinTok/{args.tokenizer}" if args.tokenizer else "PinTok")
                    if args.debug:
                        print(f"    Summary [{mode_label}] (tokens={tokens_pp}) -> mean: {fmt_us(agg['mean_us']['mean'])}, p50: {fmt_us(agg['p50_us']['mean'])}, p90: {fmt_us(agg['p90_us']['mean'])}, p99: {fmt_us(agg['p99_us']['mean'])}")
                    combo_results[mode] = agg

                # Build one wide CSV row per (dataset, tokens_pp)
                dpdk_ref = combo_results.get('dpdk')
                rust_ref = combo_results.get('rust')
                py_ref   = combo_results.get('python')
                embed_ref = combo_results.get('embed')
                try:
                    if args.debug:
                        print(f"[debug] build row tokens={tokens_pp}")
                        print(f"  dpdk_ref: runs={dpdk_ref.get('runs') if dpdk_ref else None}, has_data={dpdk_ref is not None}")
                        print(f"  rust_ref: runs={rust_ref.get('runs') if rust_ref else None}, has_data={rust_ref is not None}")
                        print(f"  py_ref: runs={py_ref.get('runs') if py_ref else None}, has_data={py_ref is not None}")
                        print(f"  embed_ref: runs={embed_ref.get('runs') if embed_ref else None}, has_data={embed_ref is not None}")
                except Exception:
                    pass

                def stat_pair(agg, metric):
                    """Return ``(mean, std)`` for *metric* from an aggregated-runs dict.

                    Gives ``(None, None)`` when *agg* is falsy or when the entry for
                    *metric* is missing or empty.
                    """
                    entry = agg.get(metric, {}) if agg else None
                    if not entry:
                        return (None, None)
                    return (entry.get('mean'), entry.get('std'))

                def ratio(a, b):
                    """Return ``a / b``; ``None`` when *a* is None or *b* is falsy (None/0)."""
                    if a is None or not b:
                        return None
                    try:
                        return a / b
                    except ZeroDivisionError:
                        # Defensive only: a truthy divisor should not trigger this.
                        return None

                try:
                    if args.debug:
                        print(f"[debug] Starting stat extraction for tokens={tokens_pp}")
                    d_mean, d_std = stat_pair(dpdk_ref, 'mean_us')
                    if args.debug:
                        print(f"[debug] dpdk stats: mean={d_mean}, std={d_std}")
                    d_p50,  d_p50_std  = stat_pair(dpdk_ref, 'p50_us')
                    d_p90,  d_p90_std  = stat_pair(dpdk_ref, 'p90_us')
                    d_p99,  d_p99_std  = stat_pair(dpdk_ref, 'p99_us')

                    r_mean, r_std = stat_pair(rust_ref, 'mean_us')
                    if args.debug:
                        print(f"[debug] rust stats: mean={r_mean}, std={r_std}")
                    r_p50,  r_p50_std  = stat_pair(rust_ref, 'p50_us')
                    r_p90,  r_p90_std  = stat_pair(rust_ref, 'p90_us')
                    r_p99,  r_p99_std  = stat_pair(rust_ref, 'p99_us')

                    p_mean, p_std = stat_pair(py_ref, 'mean_us')
                    if args.debug:
                        print(f"[debug] python stats: mean={p_mean}, std={p_std}")
                    p_p50,  p_p50_std  = stat_pair(py_ref, 'p50_us')
                    p_p90,  p_p90_std  = stat_pair(py_ref, 'p90_us')
                    p_p99,  p_p99_std  = stat_pair(py_ref, 'p99_us')

                    e_mean, e_std = stat_pair(embed_ref, 'mean_us') if embed_ref else (None, None)
                    if args.debug:
                        print(f"[debug] embed stats: mean={e_mean}, std={e_std}")
                    e_p50,  e_p50_std  = stat_pair(embed_ref, 'p50_us') if embed_ref else (None, None)
                    e_p90,  e_p90_std  = stat_pair(embed_ref, 'p90_us') if embed_ref else (None, None)
                    e_p99,  e_p99_std  = stat_pair(embed_ref, 'p99_us') if embed_ref else (None, None)
                    if args.debug:
                        print(f"[debug] All stats extracted successfully for tokens={tokens_pp}")
                except Exception as stat_ex:
                    if args.debug:
                        print(f"[debug] ERROR extracting stats for tokens={tokens_pp}: {stat_ex}")
                    import traceback
                    traceback.print_exc()
                    raise

                # Encoder latency aggregation (mean of per-run encode_mean)
                def enc_pair(mode_key: str):
                    """Mean and sample stdev of per-run ``encode_mean`` values for *mode_key*.

                    Reads ``per_mode_runs`` from the enclosing scope. Runs whose
                    ``encode_mean`` is missing, or is not convertible to float, are
                    skipped. Returns ``(None, None)`` when there are no runs or no usable
                    values. The std is ``0.0`` for a single value. Each element falls
                    back to ``None`` if its ``statistics`` call raises.
                    """
                    samples = []
                    for run_stats in per_mode_runs.get(mode_key) or []:
                        raw = run_stats.get('encode_mean')
                        if raw is None:
                            continue
                        try:
                            samples.append(float(raw))
                        except Exception:
                            # Non-numeric values are ignored rather than failing the row.
                            pass
                    if not samples:
                        return (None, None)
                    try:
                        center = statistics.mean(samples)
                    except Exception:
                        center = None
                    if len(samples) > 1:
                        try:
                            spread = statistics.stdev(samples)
                        except Exception:
                            spread = None
                    else:
                        spread = 0.0
                    return (center, spread)

                try:
                    if args.debug:
                        print(f"[debug] Starting encoder pair extraction for tokens={tokens_pp}")
                    e_d_mean, e_d_std = enc_pair('dpdk')
                    if args.debug:
                        print(f"[debug] dpdk encoder: mean={e_d_mean}, std={e_d_std}")
                    e_r_mean, e_r_std = enc_pair('rust')
                    if args.debug:
                        print(f"[debug] rust encoder: mean={e_r_mean}, std={e_r_std}")
                    e_p_mean, e_p_std = enc_pair('python')
                    if args.debug:
                        print(f"[debug] python encoder: mean={e_p_mean}, std={e_p_std}")
                    if args.debug:
                        print(f"[debug] Encoder pair extraction complete for tokens={tokens_pp}")
                except Exception as enc_ex:
                    if args.debug:
                        print(f"[debug] ERROR in encoder pair extraction for tokens={tokens_pp}: {enc_ex}")
                    import traceback
                    traceback.print_exc()
                    raise

                if args.debug:
                    print(f"[debug] building row dict tokens={tokens_pp}")
                    print(f"[debug] enc means: dpdk={e_d_mean}, rust={e_r_mean}, python={e_p_mean}")

                try:
                    row = {
                        'dataset': dataset,
                        'trials': args.runs,
                        'tokens_per_packet': tokens_pp,
                    # Average columns
                    'avg_pintok_mean_us': d_mean,
                    'avg_pintok_std_us': d_std,
                    'avg_rust_mean_us': r_mean,
                    'avg_rust_std_us': r_std,
                    'avg_improve_vs_rust': ratio(r_mean, d_mean) if d_mean and r_mean else None,
                    'avg_python_mean_us': p_mean,
                    'avg_python_std_us': p_std,
                    'avg_improve_vs_python': ratio(p_mean, d_mean) if d_mean and p_mean else None,
                    'avg_embed_mean_us': e_mean,
                    'avg_embed_std_us': e_std,
                    # Encoder latency (mean of per-run encoder means)
                    'avg_encode_pintok_mean_us': e_d_mean,
                    'avg_encode_pintok_std_us': e_d_std,
                    'avg_encode_rust_mean_us': e_r_mean,
                    'avg_encode_rust_std_us': e_r_std,
                    'avg_encode_python_mean_us': e_p_mean,
                    'avg_encode_python_std_us': e_p_std,
                    # P50 columns
                    'p50_pintok_mean_us': d_p50,
                    'p50_pintok_std_us': d_p50_std,
                    'p50_rust_mean_us': r_p50,
                    'p50_rust_std_us': r_p50_std,
                    'p50_improve_vs_rust': ratio(r_p50, d_p50) if d_p50 and r_p50 else None,
                    'p50_python_mean_us': p_p50,
                    'p50_python_std_us': p_p50_std,
                    'p50_improve_vs_python': ratio(p_p50, d_p50) if d_p50 and p_p50 else None,
                    'p50_embed_mean_us': e_p50,
                    'p50_embed_std_us': e_p50_std,
                    # P90 columns
                    'p90_pintok_mean_us': d_p90,
                    'p90_pintok_std_us': d_p90_std,
                    'p90_rust_mean_us': r_p90,
                    'p90_rust_std_us': r_p90_std,
                    'p90_improve_vs_rust': ratio(r_p90, d_p90) if d_p90 and r_p90 else None,
                    'p90_python_mean_us': p_p90,
                    'p90_python_std_us': p_p90_std,
                    'p90_improve_vs_python': ratio(p_p90, d_p90) if d_p90 and p_p90 else None,
                    'p90_embed_mean_us': e_p90,
                    'p90_embed_std_us': e_p90_std,
                    # P99 columns
                    'p99_pintok_mean_us': d_p99,
                    'p99_pintok_std_us': d_p99_std,
                    'p99_rust_mean_us': r_p99,
                    'p99_rust_std_us': r_p99_std,
                    'p99_improve_vs_rust': ratio(r_p99, d_p99) if d_p99 and r_p99 else None,
                    'p99_python_mean_us': p_p99,
                    'p99_python_std_us': p_p99_std,
                    'p99_improve_vs_python': ratio(p_p99, d_p99) if d_p99 and p_p99 else None,
                    'p99_embed_mean_us': e_p99,
                    'p99_embed_std_us': e_p99_std,
                    }
                    try:
                        if args.debug:
                            print(f"[debug] row built tokens={tokens_pp} d_mean={d_mean} e_mean={e_mean}")
                    except Exception:
                        pass
                    out_rows.append(row)

                    # Print then append this row to CSV so earlier tokens are preserved
                    _print_csv_row(row, prefix=f"[csv tokens={tokens_pp}]")
                    _append_latency_row(row)
                    try:
                        if args.debug:
                            print(f"Appended latency row: dataset={dataset}, tokens={tokens_pp}")
                    except Exception:
                        pass

                except Exception as _rb_err:
                    try:
                        if args.debug:
                            print(f"[debug] row build/append failed tokens={tokens_pp} err={_rb_err}")
                        import traceback
                        traceback.print_exc()
                    except Exception:
                        pass
                    fallback_row = {
                        'dataset': dataset,
                        'trials': args.runs,
                        'tokens_per_packet': tokens_pp,
                    }
                    out_rows.append(fallback_row)
                    _print_csv_row(fallback_row, prefix=f"[csv-fallback tokens={tokens_pp}]")
                    _append_latency_row(fallback_row)

                # Build throughput row (tokens/s) aggregated across trials per mode
                try:
                    def tps_list(mode_key: str) -> List[float]:
                        """Per-run throughput (tokens/s) for *mode_key*.

                        Each run's ``stats['mean']`` latency is converted to throughput as
                        ``tokens_pp * 1e6 / mean``. This assumes ``mean`` is in microseconds,
                        hence the 1e6 factor. Runs with zero/missing ``count`` or a
                        missing/non-positive mean are skipped. Reads ``per_mode_runs`` and
                        ``tokens_pp`` from the enclosing scope.

                        The return annotation uses ``typing.List`` (imported at file top)
                        instead of ``list[float]``. The annotation is evaluated every time
                        this nested ``def`` executes, and subscripting the builtin fails
                        before Python 3.9.
                        """
                        out = []
                        for run_stats in per_mode_runs.get(mode_key, []):
                            mu = run_stats.get('mean')
                            if run_stats.get('count', 0) and mu and mu > 0:
                                out.append((tokens_pp * 1_000_000.0) / mu)
                        return out

                    def agg(lst: List[float]) -> Tuple[Optional[float], Optional[float]]:
                        """Return ``(mean, sample_stdev)`` of *lst*, or ``(None, None)`` if empty.

                        The stdev is ``0.0`` for a single sample, because
                        ``statistics.stdev`` needs at least two points.

                        Annotations use the ``typing`` names imported at file top rather
                        than ``float | None``. Annotations are evaluated each time this
                        nested ``def`` runs, and a PEP 604 union raises ``TypeError`` before
                        Python 3.10. The surrounding ``except`` would catch that error and
                        silently skip the throughput row.
                        """
                        if not lst:
                            return (None, None)
                        m = statistics.mean(lst)
                        sd = statistics.stdev(lst) if len(lst) > 1 else 0.0
                        return (m, sd)

                    d_tps_mean, d_tps_std = agg(tps_list('dpdk'))
                    r_tps_mean, r_tps_std = agg(tps_list('rust'))
                    p_tps_mean, p_tps_std = agg(tps_list('python'))

                    def ratio_safe(a, b):
                        """Divide *a* by *b*; ``None`` if *a* is None or *b* is falsy/zero."""
                        if a is not None and b:
                            try:
                                return a / b
                            except ZeroDivisionError:
                                return None
                        return None

                    print(f"[DEBUG] Adding throughput row for {_em}: dataset={dataset}, tokens={tokens_pp}")
                    throughput_rows.append({
                        'dataset': dataset,
                        'trials': args.runs,
                        'tokens_per_packet': tokens_pp,
                        'pintok_tps_mean': d_tps_mean,
                        'pintok_tps_std': d_tps_std,
                        'rust_tps_mean': r_tps_mean,
                        'rust_tps_std': r_tps_std,
                        'python_tps_mean': p_tps_mean,
                        'python_tps_std': p_tps_std,
                        'improve_vs_rust': ratio_safe(d_tps_mean, r_tps_mean),
                        'improve_vs_python': ratio_safe(d_tps_mean, p_tps_mean),
                    })
                except Exception as e:
                    print(f"[DEBUG] Failed to add throughput row for tokens={tokens_pp}: {e}")
                    import traceback
                    traceback.print_exc()
                    try:
                        print(f"Appended fallback latency row: dataset={dataset}, tokens={tokens_pp}")
                    except Exception:
                        pass

            # Additional fraction-of-tokenization CSV (embed-model mode only)
            if (_em and _em.lower() not in ('standalone','none')):
                def frac_agg(mode_key: str):
                    """Mean and sample stdev of per-run ``e2e_fraction_mean`` for *mode_key*.

                    Reads ``per_mode_runs`` from the enclosing scope. Runs lacking the key,
                    or holding ``None``, are skipped. Returns ``(None, None)`` when nothing
                    remains, and a std of ``0.0`` for a single value.
                    """
                    fractions = []
                    for run_stats in per_mode_runs.get(mode_key, []):
                        value = run_stats.get('e2e_fraction_mean')
                        if value is not None:
                            fractions.append(value)
                    if not fractions:
                        return (None, None)
                    center = statistics.mean(fractions)
                    spread = statistics.stdev(fractions) if len(fractions) > 1 else 0.0
                    return (center, spread)
                f_d_mean, f_d_std = frac_agg('dpdk')
                f_r_mean, f_r_std = frac_agg('rust')
                f_p_mean, f_p_std = frac_agg('python')
                fraction_rows = locals().get('fraction_rows', [])
                fraction_rows.append({
                    'dataset': dataset,
                    'tokens_per_packet': tokens_pp,
                    'frac_pintok_mean': f_d_mean,
                    'frac_pintok_std': f_d_std,
                    'frac_rust_mean': f_r_mean,
                    'frac_rust_std': f_r_std,
                    'frac_python_mean': f_p_mean,
                    'frac_python_std': f_p_std,
                })
                locals()['fraction_rows'] = fraction_rows

                # Build cache row for DPDK
                try:
                    dpdk_runs_list = per_mode_runs.get('dpdk', [])
                    filt = [s for s in dpdk_runs_list if s.get('count', 0) > 0]
                    sum_lookups = sum(s.get('cache_lookups', 0) for s in filt)
                    sum_hits = sum(s.get('cache_hits', 0) for s in filt)
                    sum_inserts = sum(s.get('cache_inserts', 0) for s in filt)
                    sum_ins_fail = sum(s.get('cache_insert_fails', 0) for s in filt)
                    sum_skip_long = sum(s.get('cache_skip_longkey', 0) for s in filt)
                    sum_skip_oversize = sum(s.get('cache_skip_oversize', 0) for s in filt)
                    ratio = (sum_hits / sum_lookups) if sum_lookups else None
                    cache_rows.append({
                        'dataset': dataset,
                        'tokens_per_packet': tokens_pp,
                        'runs': len(filt),
                        'dpdk_cache_lookups': sum_lookups,
                        'dpdk_cache_hits': sum_hits,
                        'dpdk_cache_inserts': sum_inserts,
                        'dpdk_cache_insert_fails': sum_ins_fail,
                        'dpdk_cache_skip_longkey': sum_skip_long,
                        'dpdk_cache_skip_oversize': sum_skip_oversize,
                        'dpdk_cache_hit_ratio': ratio,
                    })
                except Exception:
                    pass

                # Compare embeddings if enabled and we have multiple modes with embeddings
                if args.compare_embeddings and len(per_mode_embeddings) > 1 and any(len(embs) > 0 for embs in per_mode_embeddings.values()):
                    try:
                        # Convert lists to numpy arrays if needed
                        embeddings_for_comparison = {}
                        for mode, emb_list in per_mode_embeddings.items():
                            if emb_list:
                                # Ensure all embeddings are numpy arrays
                                numpy_embs = []
                                for emb in emb_list:
                                    if isinstance(emb, np.ndarray):
                                        numpy_embs.append(emb)
                                    elif hasattr(emb, 'numpy'):
                                        numpy_embs.append(emb.numpy())
                                    else:
                                        numpy_embs.append(np.array(emb))
                                if numpy_embs:
                                    embeddings_for_comparison[mode] = numpy_embs

                        if len(embeddings_for_comparison) > 1:
                            # Save comparison results
                            comparison_csv = os.path.join(os.path.dirname(out_file),
                                                          f'embedding_comparison_{tokens_pp}tokens{suffix}.csv')
                            comparison_results = compare_embeddings(embeddings_for_comparison, comparison_csv)

                            # Print summary statistics
                            print(f"\n=== Embedding Comparison (tokens={tokens_pp}) ===")
                            for (mode1, mode2), sims in comparison_results.items():
                                if sims:
                                    mean_sim = np.mean(sims)
                                    min_sim = np.min(sims)
                                    max_sim = np.max(sims)
                                    print(f"{mode1} vs {mode2}: mean={mean_sim:.6f}, min={min_sim:.6f}, max={max_sim:.6f}")
                    except Exception as e:
                        print(f"Warning: Failed to compare embeddings: {e}")


        if args.debug:
            print(f"[DEBUG] Finished processing datasets for model {_em}")

        # Write CSV output (latency) for this embed model
        # Ensure the final CSV contains all rows merged and ordered
        if args.debug:
            print(f"[DEBUG] About to finalize CSV for model {_em}")
        _finalize_latency_csv(out_rows)
        print(f"\nSaved summary CSV to {out_file}")

        # Write overall embedding comparison summary if enabled
        if args.compare_embeddings and embed_models and any(em and em.lower() not in ('standalone','none') for em in embed_models):
            try:
                summary_path = os.path.join(os.path.dirname(out_file), f'embedding_comparison_summary{suffix}.txt')
                with open(summary_path, 'w') as f:
                    f.write("Embedding Comparison Summary\n")
                    f.write("=" * 50 + "\n\n")
                    f.write(f"Model: {_em}\n")
                    f.write(f"Modes compared: {', '.join(modes)}\n")
                    f.write(f"Token counts tested: {', '.join(map(str, token_counts))}\n\n")
                    f.write("Note: Cosine similarity values range from -1 to 1.\n")
                    f.write("Values close to 1 indicate high similarity.\n")
                    f.write("Values > 0.95 generally indicate functionally equivalent embeddings.\n")
                print(f"Saved embedding comparison summary to {summary_path}")
            except Exception as e:
                print(f"Warning: Failed to write embedding comparison summary: {e}")

        # For embed-model runs, keep track for by-model grids later
        if _em and _em.lower() not in ('standalone','none'):
            _model_csvs_all[_em] = out_file
            if args.debug:
                print(f"[DEBUG] Added model '{_em}' to _model_csvs_all. Current models: {list(_model_csvs_all.keys())}")

        # Write throughput summary CSV
        try:
            # Derive suffix from out_file name (results_summary{suffix}.csv)
            base = os.path.basename(out_file)
            tsuffix = ''
            if base.startswith('results_summary') and base.endswith('.csv'):
                tsuffix = base[len('results_summary'):-4]
            th_csv = os.path.join(os.path.dirname(out_file), f'throughput_summary{tsuffix}.csv')
            if args.debug:
                print(f"[DEBUG] Writing throughput CSV for {_em} with {len(throughput_rows)} rows")
            with open(th_csv, 'w') as f:
                headers = [
                    'dataset', 'trials', 'tokens_per_packet',
                    'pintok_tps_mean', 'pintok_tps_std',
                    'rust_tps_mean', 'rust_tps_std', 'improve_vs_rust',
                    'python_tps_mean', 'python_tps_std', 'improve_vs_python',
                    'embed_tps_mean', 'embed_tps_std',
                ]
                f.write(','.join(headers) + '\n')
                for r in throughput_rows:
                    def fmt(x):
                        # CSV cell: two decimals for floats, blank for None, str() otherwise.
                        if isinstance(x, float):
                            return f"{x:.2f}"
                        return '' if x is None else str(x)
                    f.write(','.join(fmt(r.get(h)) for h in headers) + '\n')
            print(f"Saved throughput summary CSV to {th_csv}")
            if _em and _em.lower() not in ('standalone','none'):
                _model_throughput_csvs_all[_em] = th_csv
                if args.debug:
                    print(f"[DEBUG] Added model '{_em}' to _model_throughput_csvs_all. Current models: {list(_model_throughput_csvs_all.keys())}")
        except Exception as e:
            print(f"Warning: failed to write throughput_summary.csv: {e}")

        # Write DPDK cache summary CSV (per dataset and token count)
        try:
            cache_csv = os.path.join(os.path.dirname(out_file), f'cache_summary{suffix}.csv')
            with open(cache_csv, 'w') as f:
                headers = [
                    'dataset', 'tokens_per_packet', 'runs',
                    'dpdk_cache_lookups', 'dpdk_cache_hits', 'dpdk_cache_inserts',
                    'dpdk_cache_insert_fails', 'dpdk_cache_skip_longkey', 'dpdk_cache_skip_oversize',
                    'dpdk_cache_hit_ratio'
                ]
                f.write(','.join(headers) + '\n')
                for r in cache_rows:
                    def fmt(x):
                        # CSV cell: blank for None, four decimals for floats, str() otherwise.
                        if x is None:
                            cell = ''
                        elif isinstance(x, float):
                            cell = f"{x:.4f}"
                        else:
                            cell = str(x)
                        return cell
                    f.write(','.join(fmt(r.get(h)) for h in headers) + '\n')
            print(f"Saved DPDK cache summary CSV to {cache_csv}")
        except Exception as e:
            print(f"Warning: failed to write cache_summary.csv: {e}")

        # Write tokenization fraction CSV (embed-model only)
        try:
            if (_em and _em.lower() not in ('standalone','none')):
                frac_rows = locals().get('fraction_rows', [])
                frac_csv = os.path.join(os.path.dirname(out_file), f'fraction_summary{suffix}.csv')
                with open(frac_csv, 'w') as f:
                    headers = [
                        'dataset', 'tokens_per_packet',
                        'frac_pintok_mean', 'frac_pintok_std',
                        'frac_rust_mean', 'frac_rust_std',
                        'frac_python_mean', 'frac_python_std',
                    ]
                    f.write(','.join(headers) + '\n')
                    for r in frac_rows:
                        def fmt(x):
                            # CSV cell: four decimals for floats, blank for None, str() otherwise.
                            if x is None:
                                return ''
                            if not isinstance(x, float):
                                return str(x)
                            return f"{x:.4f}"
                        f.write(','.join(fmt(r.get(h)) for h in headers) + '\n')
                print(f"Saved tokenization fraction CSV to {frac_csv}")
                # Simple plot: fraction vs tokens per packet for each dataset
                try:
                    import matplotlib.pyplot as plt
                    from collections import defaultdict
                    data = defaultdict(list)
                    for r in frac_rows:
                        data[r['dataset']].append((int(r['tokens_per_packet']), r['frac_pintok_mean'], r['frac_rust_mean'], r['frac_python_mean']))
                    for ds, pts in data.items():
                        pts.sort(key=lambda x: x[0])
                        xs = [p[0] for p in pts]
                        dp = [p[1] for p in pts]
                        ru = [p[2] for p in pts]
                        py = [p[3] for p in pts]
                        plt.figure(figsize=(4.0, 2.6))
                        plt.plot(xs, dp, marker='o', label='PinTok fraction')
                        plt.plot(xs, ru, marker='s', label='Rust fraction')
                        plt.plot(xs, py, marker='^', label='Python fraction')
                        plt.ylim(0.0, 1.0)
                        plt.xlabel('Tokens per packet')
                        plt.ylabel('Tokenization fraction of E2E')
                        plt.title(f'Tokenization Fraction vs Tokens ({ds})')
                        plt.grid(True, linestyle='--', alpha=0.3)
                        plt.legend(fontsize=8)
                        out_pdf = os.path.join(os.path.dirname(out_file), f'fraction_{ds}.pdf')
                        plt.tight_layout()
                        plt.savefig(out_pdf)
                        plt.close()
                        print(f"Saved tokenization fraction plot: {out_pdf}")
                except Exception as e:
                    print(f"Warning: failed to plot fraction: {e}")
        except Exception as e:
            print(f"Warning: failed to write fraction_summary.csv: {e}")

        # Auto-generate plots for avg, p50, p90, p99 (linear & log-y) using the CSV we just wrote
        try:
            import subprocess
            plot_script = os.path.join(THIS_DIR, 'plot_results.py')
            metrics = ['avg', 'p50', 'p90', 'p99']
            # Route plots into a model-suffixed subdirectory to avoid collisions
            base = os.path.basename(out_file)
            suffix = ''
            if base.startswith('results_summary') and base.endswith('.csv'):
                suffix = base[len('results_summary'):-4]
            plot_out_dir = args.plot_out_dir
            if suffix:
                plot_out_dir = os.path.join(plot_out_dir, suffix.lstrip('_'))
            os.makedirs(plot_out_dir, exist_ok=True)
            print(f"[DEBUG] Creating plots for model='{_em}', suffix='{suffix}', plot_out_dir='{plot_out_dir}'")
            # Use the datasets we iterated if available; fall back to all in CSV
            plot_datasets = datasets if datasets else None
            for m in metrics:
                cmd = [sys.executable, plot_script, '--csv', out_file, '--metric', m, '--out-dir', plot_out_dir]
                if plot_datasets:
                    cmd += ['--dataset'] + plot_datasets
                else:
                    cmd += ['--all-datasets']
                print(f"Generating {m} plots: {' '.join(cmd)}")
                subprocess.run(cmd, check=True)
                # Log-y variant
                cmd_log = cmd + ['--log-y']
                print(f"Generating {m} plots (log-y): {' '.join(cmd_log)}")
                subprocess.run(cmd_log, check=True)

            # Generate grids: dataset-rows 3x4 for standalone; by-model grids handled after loop
            if (not _em) or (_em.lower() in ('standalone', 'none')):
                grid_script = os.path.join(THIS_DIR, 'plot_grid.py')
                grid_cmd = [sys.executable, grid_script, '--csv', out_file, '--out-dir', plot_out_dir, '--outfile', f'plot_grid_3x4{suffix}.pdf']
                print(f"Generating 3x4 grid plot: {' '.join(grid_cmd)}")
                subprocess.run(grid_cmd, check=True)
                grid_cmd_log = [sys.executable, grid_script, '--csv', out_file, '--out-dir', plot_out_dir, '--outfile', f'plot_grid_3x4_logy{suffix}.pdf', '--log-y']
                print(f"Generating 3x4 grid plot (log-y): {' '.join(grid_cmd_log)}")
                subprocess.run(grid_cmd_log, check=True)

            # Generate throughput plots from the throughput_summary.csv written above
            tsuffix = suffix
            th_csv = os.path.join(os.path.dirname(out_file), f'throughput_summary{tsuffix}.csv')
            try:
                plot_th = os.path.join(THIS_DIR, 'plot_throughput_results.py')
                # Use datasets we iterated
                th_datasets = datasets if datasets else ['openwebtext', 'code', 'multilingual']
                if (not _em) or (_em.lower() in ('standalone', 'none')):
                    for ds in th_datasets:
                        cmd = [sys.executable, plot_th, '--csv', th_csv, '--dataset', ds, '--out-dir', plot_out_dir]
                        print(f"Generating throughput plot: {' '.join(cmd)}")
                        subprocess.run(cmd, check=True)
                    th_grid = os.path.join(THIS_DIR, 'plot_throughput_grid.py')
                    grid_cmd = [sys.executable, th_grid, '--csv', th_csv, '--out-dir', plot_out_dir]
                    print(f"Generating throughput grid: {' '.join(grid_cmd)}")
                    subprocess.run(grid_cmd, check=True)
            except Exception as e:
                print(f"Warning: failed to plot throughput: {e}")
            # Generate cache hit ratio plots by default (per dataset)
            try:
                import matplotlib.pyplot as plt
                import csv
                from collections import defaultdict
                # Use model-suffixed cache summary if present
                base = os.path.basename(out_file)
                csuffix = ''
                if base.startswith('results_summary') and base.endswith('.csv'):
                    csuffix = base[len('results_summary'):-4]
                cache_csv = os.path.join(os.path.dirname(out_file), f'cache_summary{csuffix}.csv')
                data = defaultdict(list)
                with open(cache_csv, 'r') as f:
                    reader = csv.DictReader(f)
                    for row in reader:
                        dataset = row['dataset']
                        try:
                            tpp = int(row['tokens_per_packet'])
                            ratio = float(row['dpdk_cache_hit_ratio']) if row['dpdk_cache_hit_ratio'] else None
                        except Exception:
                            continue
                        if ratio is not None:
                            data[dataset].append((tpp, ratio))
                out_dir = args.plot_out_dir if args.plot_out_dir else os.path.dirname(out_file)
                os.makedirs(out_dir, exist_ok=True)
                for dataset, pts in data.items():
                    pts.sort(key=lambda x: x[0])
                    xs = [p[0] for p in pts]
                    ys = [p[1] for p in pts]
                    plt.figure(figsize=(4.0, 2.6))
                    plt.plot(xs, ys, marker='o', linewidth=1.2, markersize=3.0, color='#1f77b4', label='DPDK cache hit ratio')
                    plt.title(f'Cache Hit Ratio vs Tokens ({dataset})')
                    plt.xlabel('Tokens per packet')
                    plt.ylabel('Cache hit ratio')
                    plt.ylim(0.0, 1.0)
                    plt.grid(True, linestyle='--', alpha=0.3, linewidth=0.5)
                    plt.legend(loc='best', fontsize=8)
                    out_pdf = os.path.join(out_dir, f'cache_hit_ratio_{dataset}.pdf')
                    plt.tight_layout()
                    plt.savefig(out_pdf)
                    plt.close()
                    print(f"Saved cache hit ratio plot: {out_pdf}")
                # Plot inserts vs tokens as well
                data_ins = defaultdict(list)
                with open(cache_csv, 'r') as f:
                    reader = csv.DictReader(f)
                    for row in reader:
                        dataset = row['dataset']
                        try:
                            tpp = int(row['tokens_per_packet'])
                            ins = int(row['dpdk_cache_inserts']) if row['dpdk_cache_inserts'] else 0
                        except Exception:
                            continue
                        data_ins[dataset].append((tpp, ins))
                for dataset, pts in data_ins.items():
                    pts.sort(key=lambda x: x[0])
                    xs = [p[0] for p in pts]
                    ys = [p[1] for p in pts]
                    plt.figure(figsize=(4.0, 2.6))
                    plt.plot(xs, ys, marker='o', linewidth=1.2, markersize=3.0, color='#2ca02c', label='DPDK cache inserts')
                    plt.title(f'Cache Inserts vs Tokens ({dataset})')
                    plt.xlabel('Tokens per packet')
                    plt.ylabel('Total inserts (sum over runs)')
                    plt.grid(True, linestyle='--', alpha=0.3, linewidth=0.5)
                    plt.legend(loc='best', fontsize=8)
                    out_pdf = os.path.join(out_dir, f'cache_inserts_{dataset}.pdf')
                    plt.tight_layout()
                    plt.savefig(out_pdf)
                    plt.close()
                    print(f"Saved cache inserts plot: {out_pdf}")
            except Exception as e:
                print(f"Warning: failed to plot cache summary: {e}")
        except Exception as e:
            print(f"Warning: automatic plotting failed: {e}")

    # If this was the final embed-model iteration, render by-model grids per dataset
    try:
        if (_i == len(embed_models) - 1) and _model_csvs_all:
            import csv as _csv3
            # Determine datasets tested (prefer the ones we iterated)
            tested_ds = datasets if datasets else []
            if not tested_ds:
                discovered = set()
                for _emk, _cpath in _model_csvs_all.items():
                    try:
                        with open(_cpath, 'r') as f:
                            for row in _csv3.DictReader(f):
                                ds = row.get('dataset')
                                if ds:
                                    discovered.add(ds)
                    except Exception:
                        pass
                tested_ds = sorted(discovered)
            out_dir = os.path.join(args.plot_out_dir, 'by_model')
            # Latency by-model grids
            model_rows_map_all = {}
            for _emk, _cpath in _model_csvs_all.items():
                try:
                    with open(_cpath, 'r') as f:
                        model_rows_map_all[_emk] = list(_csv3.DictReader(f))
                except Exception:
                    pass
            for ds in tested_ds:
                ok = _plot_latency_grid_by_model(model_rows_map_all, ds, out_dir, f'plot_grid_by_model_{ds}.pdf', log_y=False)
                if ok:
                    _ = _plot_latency_grid_by_model(model_rows_map_all, ds, out_dir, f'plot_grid_by_model_{ds}_logy.pdf', log_y=True)
            # Throughput by-model grids
            if _model_throughput_csvs_all:
                model_trows_map_all = {}
                for _emk, _tpath in _model_throughput_csvs_all.items():
                    try:
                        with open(_tpath, 'r') as f:
                            model_trows_map_all[_emk] = list(_csv3.DictReader(f))
                    except Exception:
                        pass
                print(f"[DEBUG] By-model throughput: models in map: {list(model_trows_map_all.keys())}")
                for ds in tested_ds:
                    _plot_throughput_grid_by_model(model_trows_map_all, ds, out_dir, f'plot_throughput_by_model_{ds}.pdf')
    except Exception as e:
        print(f"Warning: failed to render by-model grids: {e}")


# Script entry point: run the experiment driver only when this file is executed
# directly (e.g. `python run_experiments.py ...`), never on import.
if __name__ == "__main__":
    main()
