#!/usr/bin/env python3
"""
Tokenization Pipeline Runner with Optional BERT Encoding and Batch Processing

This script runs tokenization pipelines with optional BERT encoding and batch processing:
1. DPDK Pipeline: Packet → DPDK C code → Tokenization → Optional BERT → Python retrieval → Print
2. Rust Pipeline: Packet → Python → HuggingFace fast (Rust) tokenizer → Optional BERT → Print
3. Python Pipeline: Packet → Python → HuggingFace Python tokenizer → Optional BERT → Print

All pipelines listen on port 6000.

Usage:
    # For Rust tokenizer pipeline (infinite packets):
    python run_tokenizer.py --mode rust --warmup
    
    # For Rust pipeline with TinyBERT encoding:
    python run_tokenizer.py --mode rust --warmup --encoder tinybert
    
    # For Rust pipeline with custom timeout:
    python run_tokenizer.py --mode rust --warmup --timeout 60
    
    # For Rust pipeline with batch processing:
    python run_tokenizer.py --mode rust --warmup --enable-batch --batch-size 64
    
    # For Rust pipeline with limited packets:
    python run_tokenizer.py --mode rust --warmup --max-packets 5
    
    # For Python tokenizer (slow) with limited packets:
    python run_tokenizer.py --mode python --max-packets 5
    
    # For DPDK BPE (infinite packets, GPT-2 default):
    python run_tokenizer.py --mode dpdk --tokenizer bpe
    
    # For DPDK BPE with TinyBERT encoding:
    python run_tokenizer.py --mode dpdk --tokenizer bpe --encoder tinybert
    
    # For DPDK BPE with batch processing:
    python run_tokenizer.py --mode dpdk --tokenizer bpe --enable-batch --batch-size 128
    
    # For DPDK BPE with different model:
    python run_tokenizer.py --mode dpdk --tokenizer bpe --model modernbert-large
    
    # For Rust tokenizer with different model and TinyBERT:
    python run_tokenizer.py --mode rust --model modernbert-large --warmup --encoder tinybert
    
    # For DPDK with limited packets:
    python run_tokenizer.py --mode dpdk --tokenizer bpe --max-packets 10
"""

import argparse
import socket
import time
from typing import Dict, List, Optional

# Ensure single-core default unless explicitly disabled via CLI flag
import sys
import os

# The flag is sniffed straight from sys.argv (ahead of argparse) because the
# environment must be prepared before the utils import further down runs.
_disable_one_core = "--disable-one-core-limit" in sys.argv
if _disable_one_core:
    # Tell utils/pipelines.py not to apply its single-core defaults
    os.environ["ONE_CORE_LIMIT_DISABLED"] = "1"
    # Turn tokenizer multi-threading on unless the caller already picked a value
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "true")
    # Remove any inherited Rayon thread-count overrides
    os.environ.pop("RAYON_RS_NUM_THREADS", None)
    os.environ.pop("RAYON_NUM_THREADS", None)

# Import utilities from the utils package (after setting env)
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../src/python'))
from utils import (
    LatencyStats,
    BERTEncoder,
    DPDKPipeline,
    PythonRustPipeline
)

def handle_packet_processing(pipeline,
                            max_packets: Optional[int] = None,
                            mode_name: str = "Pipeline",
                            verbose: bool = True,
                            timeout_seconds: int = 30,
                            print_ids_only: bool = False,
                            debug: bool = False):
    """
    Common packet processing logic for all pipeline modes.

    Results are pulled one at a time through ``pipeline.get_result(timeout=...)``
    when the pipeline offers it, otherwise through ``pipeline.process_packet()``.
    A KeyboardInterrupt raised while waiting is not handled here; it propagates
    to the caller and the token IDs collected so far are not returned.

    Args:
        pipeline: Pipeline instance with get_result() or process_packet() method
        max_packets: Maximum packets to process (None or 0 = infinite)
        mode_name: Name for logging purposes
        verbose: Whether to show individual tokenization results (only has an
            effect together with ``debug``)
        timeout_seconds: Idle time in seconds after which the loop stops; only
            enforced when max_packets is set
        print_ids_only: Print only the TOKEN_IDS string of each result, after a
            start-marker line
        debug: Together with ``verbose``, print a one-line summary per result

    Returns:
        tuple: (packet_count, collected_token_ids)

    Raises:
        TypeError: If the pipeline has neither get_result() nor process_packet().
    """
    # Resolve the fetch call once. Without either method the loop below could
    # never obtain a result and would spin forever when max_packets is unset.
    if hasattr(pipeline, 'get_result'):
        # DPDK pipeline: fetch takes its own poll timeout
        poll_timeout = 5.0 if max_packets else 10.0

        def fetch_result():
            return pipeline.get_result(timeout=poll_timeout)
    elif hasattr(pipeline, 'process_packet'):
        # Python pipeline
        fetch_result = pipeline.process_packet
    else:
        raise TypeError(
            f"{type(pipeline).__name__} provides neither get_result() nor process_packet()"
        )

    packet_count = 0
    collected_token_ids = []  # Collect token IDs for saving later

    if print_ids_only:
        print(f"{mode_name} started in IDS_ONLY mode. Send packets to port 6000.")
        if max_packets:
            print(f"Will automatically stop after processing {max_packets} packets.")
        # Flush so the marker is emitted before any token ID line
        print("=== TOKEN_IDS_START ===", flush=True)
    else:
        print(f"{mode_name} started. Send packets to port 6000. batch size: {getattr(pipeline, 'batch_size', 128)}")
        if max_packets:
            print(f"Will automatically stop after processing {max_packets} packets.")
        print("Press Ctrl+C to stop and show statistics.")

    # Monotonic clock: the idle timer must not be affected by wall-clock jumps
    last_activity = time.monotonic()

    while True:
        # Check packet limit
        if max_packets and packet_count >= max_packets:
            print(f"\nProcessed {max_packets} packets. Stopping automatically...")
            break

        # Check timeout (only when max_packets is set)
        if max_packets and time.monotonic() - last_activity > timeout_seconds:
            print(f"\nTimeout: No packets received for {timeout_seconds} seconds. Stopping...")
            break

        result = fetch_result()
        if not result:
            continue

        packet_count += 1
        # Stamp activity AFTER the (possibly blocking) fetch returned. Using a
        # timestamp taken before the fetch would count the wait itself as idle
        # time and could trigger a timeout right after a packet arrived.
        last_activity = time.monotonic()

        # Collect token IDs for saving later
        token_ids_str = result.get('TOKEN_IDS', '')
        if token_ids_str:
            collected_token_ids.append(token_ids_str)

        if print_ids_only:
            print(token_ids_str)
        elif verbose and debug:
            text = result.get('ORIGINAL_TEXT', '')
            tokens = result.get('NUM_TOKENS', 0)
            token_str = result.get('TOKENS', '')
            bert_shape = result.get('BERT_EMBEDDINGS_SHAPE', '')
            # Truncate tokens if too long for display
            if len(token_str) > 60:
                token_preview = token_str[:57] + "..."
            else:
                token_preview = token_str
            base_info = f"Text: '{text}', Tokens: {tokens}, Token Preview: [{token_preview}]"
            if bert_shape:
                base_info += f", BERT: {bert_shape}"
            if max_packets:
                print(f"[{packet_count}/{max_packets}] {mode_name} Result - {base_info}")
            else:
                print(f"{mode_name} Result - {base_info}")

    return packet_count, collected_token_ids


def handle_batch_processing(pipeline,
                            max_packets: Optional[int] = None,
                            mode_name: str = "Pipeline",
                            verbose: bool = True,
                            timeout_seconds: int = 30,
                            debug: bool = False):
    """
    Common batch processing logic for all pipeline modes.

    Batches are pulled through ``pipeline.get_batch_results(timeout=...)`` when
    the pipeline offers it, otherwise through
    ``pipeline.process_batch_packets(timeout=...)``. Because whole batches are
    counted, the returned packet count can exceed max_packets. A
    KeyboardInterrupt raised while waiting is not handled here; it propagates
    to the caller and the token IDs collected so far are not returned.

    Args:
        pipeline: Pipeline instance with get_batch_results() or process_batch_packets() method
        max_packets: Maximum packets to process (None or 0 = infinite)
        mode_name: Name for logging purposes
        verbose: Whether to show individual tokenization results (only has an
            effect together with ``debug``; otherwise a one-line batch summary
            is printed)
        timeout_seconds: Idle time in seconds after which the loop stops; only
            enforced when max_packets is set
        debug: Together with ``verbose``, print every result of each batch

    Returns:
        tuple: (packet_count, collected_token_ids)

    Raises:
        TypeError: If the pipeline has neither get_batch_results() nor
            process_batch_packets().
    """
    # Resolve the fetch call once. Without either method the loop below could
    # never obtain a batch and would spin forever when max_packets is unset.
    if hasattr(pipeline, 'get_batch_results'):
        # DPDK pipeline
        fetch_batch = pipeline.get_batch_results
    elif hasattr(pipeline, 'process_batch_packets'):
        # Python pipeline
        fetch_batch = pipeline.process_batch_packets
    else:
        raise TypeError(
            f"{type(pipeline).__name__} provides neither get_batch_results() nor process_batch_packets()"
        )
    poll_timeout = 5.0 if max_packets else 10.0

    packet_count = 0
    batch_count = 0
    collected_token_ids = []  # Collect token IDs for saving later

    print(f"{mode_name} started. Send packets to port 6000.")
    if max_packets:
        print(f"Will automatically stop after processing {max_packets} packets.")
    print("Press Ctrl+C to stop and show statistics.")
    print(f"Batch size: {getattr(pipeline, 'batch_size', 128)}")

    # Monotonic clock: the idle timer must not be affected by wall-clock jumps
    last_activity = time.monotonic()

    while True:
        # Check packet limit
        if max_packets and packet_count >= max_packets:
            print(f"\nProcessed {max_packets} packets. Stopping automatically...")
            break

        # Check timeout (only when max_packets is set)
        if max_packets and time.monotonic() - last_activity > timeout_seconds:
            print(f"\nTimeout: No packets received for {timeout_seconds} seconds. Stopping...")
            break

        batch_results = fetch_batch(timeout=poll_timeout)
        if not batch_results:
            continue

        batch_count += 1
        batch_size = len(batch_results)
        packet_count += batch_size
        # Stamp activity AFTER the (possibly blocking) fetch returned. Using a
        # timestamp taken before the fetch would count the wait itself as idle
        # time and could trigger a timeout right after a batch arrived.
        last_activity = time.monotonic()

        # Collect token IDs from batch
        for result in batch_results:
            token_ids_str = result.get('TOKEN_IDS', '')
            if token_ids_str:
                collected_token_ids.append(token_ids_str)

        if verbose and debug:
            print(f"\nBatch {batch_count} processed: {batch_size} packets")

            # Show individual results in the batch
            for i, result in enumerate(batch_results):
                text = result.get('ORIGINAL_TEXT', '')
                tokens = result.get('NUM_TOKENS', 0)
                token_str = result.get('TOKENS', '')
                bert_info = ""

                if result.get('BERT_BATCH_SIZE'):
                    bert_info = f", BERT Batch: {result.get('BERT_BATCH_SIZE')} (pos {i})"
                elif result.get('BERT_EMBEDDINGS_SHAPE'):
                    bert_info = f", BERT: {result.get('BERT_EMBEDDINGS_SHAPE')}"

                # Truncate tokens if too long for display
                if len(token_str) > 40:
                    token_preview = token_str[:37] + "..."
                else:
                    token_preview = token_str

                # Only mark the text as truncated when something was cut off
                if len(text) > 30:
                    text_preview = text[:30] + "..."
                else:
                    text_preview = text

                print(f"  [{i+1}/{batch_size}] Text: '{text_preview}', Tokens: {tokens}, Preview: [{token_preview}]{bert_info}")
        else:
            # Just show batch summary
            avg_tokens = sum(r.get('NUM_TOKENS', 0) for r in batch_results) / batch_size
            print(f"Batch {batch_count}: {batch_size} packets, avg {avg_tokens:.1f} tokens")

    return packet_count, collected_token_ids


def save_token_ids(token_ids_list: List[str], mode: str, tokenizer_type: str = ""):
    """
    Write the collected token IDs to a text file in the current directory.

    Nothing is written (and nothing is printed) when the list is empty.

    Args:
        token_ids_list: One string per packet; each string is written on its own line
        mode: Pipeline mode; "dpdk", "rust" and "python" pick their own filename
            prefix, anything else uses the legacy "python-rust" prefix
        tokenizer_type: Tokenizer type (DPDK) or model name (other modes) used
            in the filename
    """
    if not token_ids_list:
        return

    if mode == "dpdk":
        # DPDK uses the tokenizer type verbatim
        stem = f"dpdk_{tokenizer_type}"
    else:
        # Unrecognised modes keep the legacy prefix for backward compatibility
        prefix = mode if mode in ("rust", "python") else "python-rust"
        # Model names may contain "/" which cannot appear in a filename
        safe_model = tokenizer_type.replace("/", "_") if tokenizer_type else "default"
        stem = f"{prefix}_{safe_model}"
    filename = f"{stem}_output.txt"

    # One line per packet
    with open(filename, 'w') as out:
        out.writelines(ids + '\n' for ids in token_ids_list)

    print(f"\nToken IDs saved to {filename} ({len(token_ids_list)} packets)")


def format_latency(value_us: float) -> str:
    """Render a latency given in microseconds, switching to ms at 1000 us and above."""
    use_ms = value_us >= 1000
    scaled, unit = (value_us / 1000, "ms") if use_ms else (value_us, "us")
    return f"{scaled:.2f} {unit}"


def print_pipeline_stats(pipeline, pipeline_name: str):
    """Print the latency summary and per-stage breakdown recorded by a pipeline.

    Reads ``pipeline.latency_stats.stats()`` and prints every entry except the
    per-stage means, followed by a stage breakdown when ``tokenize_mean`` is
    present. Setting the DEBUG_STATS environment variable adds diagnostic
    lines about the raw tokenize samples.

    Args:
        pipeline: Object exposing a ``latency_stats`` attribute
        pipeline_name: Label used in the printed headings
    """
    latency = pipeline.latency_stats
    stats = latency.stats()
    batched = getattr(latency, 'is_batch_mode', False)

    # Optional diagnostics, enabled through the environment
    if os.environ.get("DEBUG_STATS"):
        print(f"DEBUG: Stats keys for {pipeline_name}: {list(stats.keys())}")
        samples = latency.tokenize_times
        print(f"  - tokenize_times count: {len(samples)}")
        if samples:
            print(f"  - tokenize_times sample: {samples[:3]}")

    # Heading is flushed immediately
    print(f"\n{pipeline_name} Statistics:", flush=True)

    if batched:
        print("  Mode: Batch Processing (per-packet amortized)")

    # Per-stage means are reported in the breakdown section, not the main list
    stage_keys = ("parse_mean", "extract_mean", "tokenize_mean", "encode_mean")
    for name, val in stats.items():
        if name == "count":
            print(f"  {name}: {val:.0f}")
        elif name not in stage_keys:
            print(f"  {name}: {format_latency(val)}")

    # No breakdown without tokenize timing
    if "tokenize_mean" not in stats:
        return

    print(f"\n  Latency Breakdown (mean per packet):", flush=True)
    tokenize = stats['tokenize_mean']

    if batched:
        # Batch mode reports tokenization and encoding only
        print(f"    Tokenize:  {format_latency(tokenize):>12}  (tokenization - amortized)")
        has_encode = "encode_mean" in stats and stats['encode_mean'] > 0
        if not has_encode:
            print(f"    Total:     {format_latency(tokenize):>12}  (actual processing per packet)")
        else:
            encode = stats['encode_mean']
            print(f"    Encode:    {format_latency(encode):>12}  (BERT encoding - amortized)")
            total = tokenize + encode
            tokenize_pct = (tokenize / total) * 100
            encode_pct = (encode / total) * 100
            print(f"    Total:     {format_latency(total):>12}  (actual processing per packet)")
            print(f"    Breakdown: {tokenize_pct:5.1f}% tokenize, {encode_pct:5.1f}% encode")
    else:
        # Packet mode reports every stage; missing stages count as 0
        parse = stats.get('parse_mean', 0)
        extract = stats.get('extract_mean', 0)
        print(f"    Parse:     {format_latency(parse):>12}  (packet header parsing)")
        print(f"    Extract:   {format_latency(extract):>12}  (text extraction/decoding)")
        print(f"    Tokenize:  {format_latency(tokenize):>12}  (tokenization algorithm)")

        stages = [(parse, "parse"), (extract, "extract"), (tokenize, "tokenize")]
        total = parse + extract + tokenize
        # BERT encoding is listed only when a positive mean was recorded
        if "encode_mean" in stats and stats['encode_mean'] > 0:
            encode = stats['encode_mean']
            print(f"    Encode:    {format_latency(encode):>12}  (BERT encoding)")
            total += encode
            stages.append((encode, "encode"))

        print(f"    Total:     {format_latency(total):>12}  (sum of components)")

        # Percentages are only meaningful for a positive total
        if total > 0:
            shares = ", ".join(f"{(value / total) * 100:5.1f}% {label}" for value, label in stages)
            print(f"    Breakdown: {shares}")

    # Ensure all breakdown stats are flushed
    sys.stdout.flush()

def run_dpdk_mode(tokenizer_type: str = "bpe",
                  max_packets: int = None,
                  model: str = "modernbert-base",
                  encoder_type: str = None,
                  verbose: bool = True,
                  force_cpu: bool = False,
                  debug_bert: bool = False,
                  debug: bool = False,
                  batch_size: int = 128,
                  enable_batch: bool = False,
                  timeout_seconds: int = 30,
                  print_ids_only: bool = False,
                  pin_core: int | None = None,
                  rt_prio: int | None = None,
                  dpdk_log_level: str | None = None,
                  use_sudo: bool = False,
                  allow_non_isolated: bool = False,
                  disable_cache: bool = False):
    """Run the DPDK pipeline from start to final report.

    Builds a DPDKPipeline, starts it, and hands it to the batch or per-packet
    processing loop. The pipeline is always stopped afterwards. Unless
    print_ids_only is set, statistics are printed and the collected token IDs
    are saved. When the loop is left through Ctrl+C it never returns its
    collected IDs, so the list passed to save_token_ids is empty in that case.
    """
    pipeline = DPDKPipeline(
        tokenizer_type=tokenizer_type,
        model=model,
        encoder_type=encoder_type,
        force_cpu=force_cpu,
        debug_bert=debug_bert,
        debug=debug,
        batch_size=batch_size,
        enable_batch=enable_batch,
        pin_core=pin_core,
        rt_prio=rt_prio,
        dpdk_log_level=dpdk_log_level,
        use_sudo=use_sudo,
        allow_non_isolated=allow_non_isolated,
        disable_cache=disable_cache,
    )

    mode_label = f"DPDK {tokenizer_type}"
    # IDs-only runs keep stdout free of status and statistics output
    show_report = not print_ids_only
    collected_token_ids = []
    try:
        pipeline.start()
        if enable_batch:
            _, collected_token_ids = handle_batch_processing(
                pipeline, max_packets, f"{mode_label} (Batch)", verbose, timeout_seconds, debug)
        else:
            _, collected_token_ids = handle_packet_processing(
                pipeline, max_packets, mode_label, verbose, timeout_seconds, print_ids_only, debug)
    except KeyboardInterrupt:
        print("\nStopping DPDK pipeline...")
    finally:
        if show_report:
            print("\nFinalizing DPDK pipeline...", flush=True)
        pipeline.stop()
        if show_report:
            print_pipeline_stats(pipeline, "DPDK Pipeline")
            # Save token IDs after profiling is complete
            save_token_ids(collected_token_ids, "dpdk", tokenizer_type)
            print("DPDK pipeline statistics complete.", flush=True)

def run_python_rust_mode(model: str = "modernbert-base",
                         enable_warmup: bool = False,
                         max_packets: int = None,
                         encoder_type: str = None,
                         verbose: bool = True,
                         force_cpu: bool = False,
                         debug: bool = False,
                         batch_size: int = 128,
                         enable_batch: bool = False,
                         timeout_seconds: int = 30,
                         print_ids_only: bool = False,
                         use_fast: bool = True,
                         pipeline_label: str = "Rust",
                         disable_cache: bool = False):
    """Run the HuggingFace tokenizer pipeline (fast Rust or Python).

    Builds a PythonRustPipeline, starts it, and hands it to the batch or
    per-packet processing loop. The pipeline is always stopped afterwards.
    Unless print_ids_only is set, statistics are printed and the collected
    token IDs are saved under the lower-cased pipeline_label. When the loop is
    left through Ctrl+C it never returns its collected IDs, so the list passed
    to save_token_ids is empty in that case.
    """
    pipeline = PythonRustPipeline(
        model_name=model,
        encoder_type=encoder_type,
        force_cpu=force_cpu,
        debug=debug,
        batch_size=batch_size,
        enable_batch=enable_batch,
        use_fast=use_fast,
        warmup=enable_warmup,
        disable_cache=disable_cache,
    )

    # IDs-only runs keep stdout free of status and statistics output
    show_report = not print_ids_only
    collected_token_ids = []
    try:
        pipeline.start()
        if enable_warmup:
            print("Warmup enabled - model should be pre-initialized")
        if enable_batch:
            _, collected_token_ids = handle_batch_processing(
                pipeline, max_packets, f"{pipeline_label} (Batch)", verbose, timeout_seconds, debug)
        else:
            _, collected_token_ids = handle_packet_processing(
                pipeline, max_packets, f"{pipeline_label}", verbose, timeout_seconds, print_ids_only, debug)
    except KeyboardInterrupt:
        print(f"\nStopping {pipeline_label} pipeline...")
    finally:
        if show_report:
            print(f"\nFinalizing {pipeline_label} pipeline...", flush=True)
        pipeline.stop()
        if show_report:
            print_pipeline_stats(pipeline, f"{pipeline_label} Pipeline")
            # Save token IDs after profiling is complete
            save_token_ids(collected_token_ids, pipeline_label.lower(), model)
            print(f"{pipeline_label} pipeline statistics complete.", flush=True)

def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser; every option is registered exactly once."""
    parser = argparse.ArgumentParser(description="Tokenization Pipeline Runner")
    parser.add_argument("--mode", choices=["dpdk", "rust", "python"], required=True,
                        help="Pipeline mode to run")
    parser.add_argument("--model", default="gpt2",
                        help="Model for both pipelines (gpt2 default; can use modernbert-base, modernbert-large, etc.)")
    parser.add_argument("--tokenizer", choices=["simple", "wordpiece", "bpe"], default="bpe",
                        help="DPDK tokenizer type to use")
    parser.add_argument("--max-packets", type=int, default=None,
                        help="Maximum number of packets to process before stopping (default: infinite)")
    # NOTE: --warmup must be added only once; argparse rejects a second
    # registration of the same option string while the parser is being built.
    parser.add_argument("--warmup", action="store_true",
                        help="Enable model warmup to avoid initialization overhead")
    parser.add_argument("--verbose", action="store_true",
                        help="Show individual tokenization results (default: only show statistics)")
    parser.add_argument("--encoder", choices=["tinybert"], default=None,
                        help="Enable BERT encoding after tokenization (tinybert)")
    parser.add_argument("--cpu", action="store_true",
                        help="Force CPU usage for BERT encoding instead of GPU (useful for performance comparison)")
    parser.add_argument("--debug", action="store_true",
                        help="Enable comprehensive debug output showing tokenization workflow steps")
    parser.add_argument("--debug-bert", action="store_true",
                        help="Show first few BERT embedding values for debugging (impacts performance)")
    parser.add_argument("--batch-size", type=int, default=128,
                        help="Batch size for batch processing (default: 128)")
    parser.add_argument("--enable-batch", action="store_true",
                        help="Enable batch processing for improved throughput")
    parser.add_argument("--timeout", type=int, default=30,
                        help="Timeout in seconds for waiting for packets (default: 30)")
    parser.add_argument("--disable-cache", action="store_true",
                        help="Disable tokenizer caches (Rust fast/Python and DPDK BPE)")
    parser.add_argument("--print-ids-only", action="store_true",
                        help="Print only token IDs to stdout (mutually exclusive with --debug)")
    # DPDK runtime tuning shortcuts
    parser.add_argument("--pin-core", type=int, default=None,
                        help="Pin DPDK tokenizer process to a CPU core (sets DPDK_PIN_CORE for child)")
    parser.add_argument("--rt-prio", type=int, default=None,
                        help="Run tokenizer with SCHED_FIFO at priority (sets DPDK_RT_PRIO; needs sudo/CAP_SYS_NICE)")
    parser.add_argument("--dpdk-log-level", default=None,
                        help="DPDK EAL log level (0..8) for child (overrides DPDK_LOG_LEVEL)")
    parser.add_argument("--disable-sudo", action="store_true",
                        help="Do NOT use sudo to launch the DPDK tokenizer (default is sudo for best perf)")
    parser.add_argument("--allow-non-isolated", action="store_true",
                        help="Allow running when pin core is not in isolcpus (reduced determinism)")
    # Also inspected directly in sys.argv at import time; registered here so
    # that argparse accepts it and lists it in --help.
    parser.add_argument("--disable-one-core-limit", action="store_true",
                        help="Allow tokenizers to use multiple threads (default is single-core)")
    return parser


def main():
    """Parse the command line and run the selected tokenization pipeline.

    Arguments are parsed and validated first; hanging tokenizer processes are
    cleaned up only afterwards, so ``--help`` or an invalid command line exits
    without killing anything.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate mutually exclusive options
    if args.print_ids_only and args.debug:
        parser.error("--print-ids-only and --debug cannot be used together")

    # Clean up any hanging tokenizer processes before starting
    from cleanup_tokenizers import find_tokenizer_processes, kill_processes
    hanging_pids = find_tokenizer_processes()
    if hanging_pids:
        print(f"Cleaning up {len(hanging_pids)} hanging tokenizer process(es)...")
        kill_processes(hanging_pids)

    if args.mode == "dpdk":
        # Defaults when not provided: core 0, RT priority 80, log level "3"
        effective_pin_core = args.pin_core if args.pin_core is not None else 0
        effective_rt_prio = args.rt_prio if args.rt_prio is not None else 80
        effective_dpdk_log = args.dpdk_log_level if args.dpdk_log_level is not None else "3"

        run_dpdk_mode(
            tokenizer_type=args.tokenizer,
            max_packets=args.max_packets,
            model=args.model,
            encoder_type=args.encoder,
            verbose=args.verbose,
            force_cpu=args.cpu,
            debug_bert=args.debug_bert,
            debug=args.debug,
            batch_size=args.batch_size,
            enable_batch=args.enable_batch,
            timeout_seconds=args.timeout,
            print_ids_only=args.print_ids_only,
            pin_core=effective_pin_core,
            rt_prio=effective_rt_prio,
            dpdk_log_level=effective_dpdk_log,
            use_sudo=(not args.disable_sudo),
            allow_non_isolated=args.allow_non_isolated,
            disable_cache=args.disable_cache,
        )
    else:
        # "rust" and "python" share one runner and differ only in the
        # tokenizer implementation and the label used in output
        use_fast = args.mode == "rust"
        run_python_rust_mode(
            model=args.model,
            enable_warmup=args.warmup,
            max_packets=args.max_packets,
            encoder_type=args.encoder,
            verbose=args.verbose,
            force_cpu=args.cpu,
            debug=args.debug,
            batch_size=args.batch_size,
            enable_batch=args.enable_batch,
            timeout_seconds=args.timeout,
            print_ids_only=args.print_ids_only,
            use_fast=use_fast,
            pipeline_label="Rust" if use_fast else "Python",
            disable_cache=args.disable_cache,
        )

if __name__ == "__main__":
    main()
