#!/usr/bin/env python3
"""
Configurable N-packet test for tokenization performance measurement
"""

import socket
import time
import sys
import argparse
import os
import queue
from typing import List, Iterable, Tuple, Callable, Optional

# Verbosity control: default True for direct CLI usage. send_test_packets()
# overwrites this flag from its `verbose` argument on every call, so a caller
# such as the experiment runner can pass verbose=False to suppress
# non-essential prints.
SENDER_VERBOSE: bool = True

def get_simple_sentences():
    """Return the built-in pool of 9 short English test sentences.

    A new list object is built on every call, so callers may mutate the
    result without affecting later calls.
    """
    builtin_sentences = (
        "Natural language processing with DPDK provides high-performance tokenization.",
        "Hello world!",
        "This is a simple test message.",
        "The quick brown fox jumps over the lazy dog.",
        "Machine learning models require efficient text processing.",
        "Deep learning enables complex pattern recognition.",
        "Transformers have become the backbone of modern NLP.",
        "Data preprocessing is crucial for model accuracy.",
        "GPU acceleration significantly improves training speed.",
    )
    return list(builtin_sentences)

def _split_code_into_snippets(code, min_lines=5, max_lines=20, min_chars=50, max_chars=500):
    """Cut one source file into packet-sized snippets at natural boundaries.

    Lines are accumulated into a chunk. A chunk is closed when it holds at
    least ``min_lines`` lines and the current line is a boundary (blank, or
    starting with ``def ``, ``function `` or ``class `` after stripping).
    The boundary line itself stays in the chunk it closes. A closed chunk is
    kept only if it has at most ``max_lines`` lines and its stripped text is
    between ``min_chars`` and ``max_chars`` characters long.

    A chunk that has already grown past ``max_lines`` is discarded at the
    next boundary and chunking restarts from there, so one long unbroken
    stretch (e.g. a licence header) does not swallow the rest of the file.

    Returns:
        List of snippet strings, in file order.
    """
    snippets = []
    chunk = []

    for line in code.split('\n'):
        chunk.append(line)
        stripped = line.strip()
        at_boundary = (not stripped) or stripped.startswith(('def ', 'function ', 'class '))
        if not at_boundary or len(chunk) < min_lines:
            continue

        if len(chunk) <= max_lines:
            chunk_text = '\n'.join(chunk).strip()
            if min_chars <= len(chunk_text) <= max_chars:
                snippets.append(chunk_text)
        # Reset even when the chunk was too long to keep; otherwise the
        # chunk would keep growing and nothing else would be extracted.
        chunk = []

    # Whatever is left at end of file is kept under the same size limits.
    if chunk and len(chunk) <= max_lines:
        chunk_text = '\n'.join(chunk).strip()
        if min_chars <= len(chunk_text) <= max_chars:
            snippets.append(chunk_text)

    return snippets


def get_code_snippets(num_snippets):
    """Load code snippets from GitHub Code dataset.

    Snippets are read from the pickle cache ``.code_cache.pkl`` next to this
    file when it exists and is readable. Otherwise up to 10000 examples are
    streamed from ``codeparrot/github-code``, split into snippets and written
    back to the cache. The snippet list is then shuffled with the global
    ``random`` generator seeded to 42, so the order is the same on every call
    for the same cache contents.

    Args:
        num_snippets: Maximum number of snippets to return.

    Returns:
        A list of at most ``num_snippets`` snippet strings. If fewer are
        available a warning is printed and all of them are returned.

    Exits the process with status 1 if the ``datasets`` package is missing
    or any other error occurs while loading.
    """
    try:
        from datasets import load_dataset
        import itertools
        import pickle
        import random

        # Cache file path
        cache_dir = os.path.dirname(os.path.abspath(__file__))
        cache_file = os.path.join(cache_dir, ".code_cache.pkl")

        code_snippets = []

        # Try to load from cache first
        if os.path.exists(cache_file):
            try:
                # NOTE(security): pickle.load executes arbitrary code from the
                # file. The cache is written by this script only; do not point
                # it at files from an untrusted source.
                with open(cache_file, 'rb') as f:
                    cached_data = pickle.load(f)
                code_snippets = cached_data['code_snippets']
                if SENDER_VERBOSE:
                    print(f"Loaded {len(code_snippets)} cached code snippets")
            except Exception as e:
                print(f"Cache read failed, will reload: {e}")
                code_snippets = []

        # If cache doesn't exist or failed to load, fetch from dataset
        if not code_snippets:
            print("Loading GitHub Code dataset (this may take a moment)...")

            # Use streaming to avoid downloading the entire dataset
            ds_code = load_dataset(
                "codeparrot/github-code",
                split="train",
                streaming=True,
                trust_remote_code=True
            )

            # Take only what we need
            max_examples = 10000
            code_examples = list(itertools.islice(ds_code, max_examples))
            print(f"Loaded {len(code_examples)} examples from GitHub Code dataset")

            # Extract code snippets, skipping rows without usable source text
            for item in code_examples:
                if not item or "code" not in item:
                    continue
                code = item["code"]
                if not code or not isinstance(code, str):
                    continue
                code_snippets.extend(_split_code_into_snippets(code))

            # Cache the extracted code snippets
            print(f"Extracted {len(code_snippets)} code snippets, saving to cache...")
            try:
                with open(cache_file, 'wb') as f:
                    pickle.dump({'code_snippets': code_snippets}, f)
                print(f"Cache saved to {cache_file}")
            except Exception as e:
                print(f"Warning: Failed to save cache: {e}")

        # Shuffle with fixed seed for reproducibility
        random.seed(42)
        random.shuffle(code_snippets)

        # Return requested number of snippets
        if len(code_snippets) < num_snippets:
            print(f"Warning: Only found {len(code_snippets)} suitable code snippets")
            return code_snippets

        return code_snippets[:num_snippets]

    except ImportError:
        print("Error: 'datasets' library not installed. Install with: pip install datasets")
        sys.exit(1)
    except Exception as e:
        print(f"Error loading code dataset: {e}")
        sys.exit(1)

def get_multilingual_sentences(num_sentences):
    """Load sentences from multilingual MC4 dataset.

    Sentences are read from the pickle cache ``.multilingual_cache.pkl`` next
    to this file when it exists, is readable and is not tagged only with the
    ``'unknown'`` language. Otherwise 2000 examples per language (en, fr, de,
    es, zh) are streamed from ``bertin-project/mc4-sampling``, split into
    sentences and written back to the cache.

    Chinese text is split on ``。！？`` and pieces of 10-100 characters are
    kept with ``。`` appended. Other languages are split on ``'. '`` and
    pieces of 10-50 words that are longer than 20 characters are kept with
    ``.`` appended.

    All sentences are then mixed across languages using the global ``random``
    generator seeded to 42, so the selection is the same on every call for
    the same cache contents.

    Args:
        num_sentences: Maximum number of sentences to return.

    Returns:
        A list of at most ``num_sentences`` sentence strings (language tags
        are dropped). If fewer are available a warning is printed and all of
        them are returned.

    Exits the process with status 1 if the ``datasets`` package is missing
    or any other error occurs while loading.
    """
    try:
        from datasets import load_dataset
        import itertools
        import pickle
        import random
        import re

        # Cache file path
        cache_dir = os.path.dirname(os.path.abspath(__file__))
        cache_file = os.path.join(cache_dir, ".multilingual_cache.pkl")

        sentences_by_lang = {}

        # Try to load from cache first
        if os.path.exists(cache_file):
            try:
                # NOTE(security): pickle.load executes arbitrary code from the
                # file. The cache is written by this script only; do not point
                # it at files from an untrusted source.
                with open(cache_file, 'rb') as f:
                    cached_data = pickle.load(f)
                sentences_by_lang = cached_data['sentences_by_lang']

                # Check if cache has proper language tags (not just 'unknown')
                if len(sentences_by_lang) == 1 and 'unknown' in sentences_by_lang:
                    print("Cache has incorrect language tags, will regenerate...")
                    sentences_by_lang = {}
                else:
                    # Computed unconditionally so a malformed cache still
                    # fails here and triggers a reload.
                    total_sentences = sum(len(sents) for sents in sentences_by_lang.values())
                    cached_langs = sorted(sentences_by_lang.keys())
                    if SENDER_VERBOSE:
                        print(f"Loaded {total_sentences} cached multilingual sentences from {len(sentences_by_lang)} languages")
                        # Show language breakdown
                        for lang in cached_langs:
                            print(f"  {lang}: {len(sentences_by_lang[lang])} sentences")
            except Exception as e:
                print(f"Cache read failed, will reload: {e}")
                sentences_by_lang = {}

        # If cache doesn't exist or failed to load, fetch from dataset
        if not sentences_by_lang:
            print("Loading multilingual MC4 dataset (this may take a moment)...")
            languages = ["en", "fr", "de", "es", "zh"]

            # Use streaming to avoid loading the entire dataset
            # Load each language separately to ensure proper distribution
            texts_with_lang = []
            examples_per_lang = 2000  # 2000 examples per language = 10000 total for 5 languages

            for lang in languages:
                print(f"Loading {lang} examples...")
                try:
                    ds_lang = load_dataset(
                        "bertin-project/mc4-sampling",
                        languages=[lang],
                        split="train",
                        streaming=True,
                        trust_remote_code=True
                    )

                    # Take examples for this language
                    lang_texts = list(itertools.islice(ds_lang, examples_per_lang))

                    # Add language tag to each item
                    for item in lang_texts:
                        item['language'] = lang  # Ensure language is set
                        texts_with_lang.append(item)

                    print(f"  Loaded {len(lang_texts)} examples for {lang}")
                except Exception as e:
                    # A failing language is skipped; the others are still used.
                    print(f"  Warning: Failed to load {lang}: {e}")

            random.seed(42)  # Fixed seed for consistent dataset loading
            random.shuffle(texts_with_lang)  # Mix languages
            print(f"Loaded {len(texts_with_lang)} examples from MC4 multilingual dataset")

            # Debug: Check structure of first item
            if texts_with_lang:
                print(f"First item keys: {list(texts_with_lang[0].keys())}")
                # Don't try to slice the dict, only print it
                sample_str = str(texts_with_lang[0])
                if len(sample_str) > 500:
                    print(f"First item sample: {sample_str[:500]}...")
                else:
                    print(f"First item sample: {sample_str}")

            # Compiled once: sentence terminators used for Chinese text
            zh_sentence_end = re.compile(r'[。！？]')

            # Extract sentences from the texts, organized by language
            for item in texts_with_lang:
                # Skip items with missing text
                if not item or "text" not in item or item["text"] is None:
                    continue

                text = item["text"]
                # MC4 dataset uses 'language' field, not 'lang'
                lang = item.get("language", item.get("lang", "unknown"))

                # If still unknown, report which keys the item did carry
                if lang == "unknown" and "language" not in item and "lang" not in item:
                    print(f"Warning: Could not determine language for item. Keys available: {list(item.keys())}")

                # The language bucket is created even if this text is skipped below
                lang_sentences = sentences_by_lang.setdefault(lang, [])

                # Skip empty texts
                if not text or not isinstance(text, str):
                    continue

                if lang == "zh":
                    # Character-count filter; the terminator consumed by the
                    # split is replaced with a full stop.
                    for sent in zh_sentence_end.split(text):
                        sent = sent.strip()
                        if 10 <= len(sent) <= 100:
                            lang_sentences.append(sent + "。")
                else:
                    # Split on period followed by space (newlines flattened first)
                    for sent in text.replace('\n', ' ').split('. '):
                        sent = sent.strip()
                        if 10 <= len(sent.split()) <= 50 and len(sent) > 20:
                            lang_sentences.append(sent + ".")

            # Cache the extracted sentences
            total_sentences = sum(len(sents) for sents in sentences_by_lang.values())
            print(f"Extracted {total_sentences} sentences across {len(sentences_by_lang)} languages")
            for lang, sents in sentences_by_lang.items():
                print(f"  {lang}: {len(sents)} sentences")

            print("Saving to cache...")
            try:
                with open(cache_file, 'wb') as f:
                    pickle.dump({'sentences_by_lang': sentences_by_lang}, f)
                print(f"Cache saved to {cache_file}")
            except Exception as e:
                print(f"Warning: Failed to save cache: {e}")

        # Mix sentences from different languages randomly; each sentence is
        # tagged with its language so the selection can be summarised below.
        all_sentences = [
            (sent, lang)
            for lang, sents in sentences_by_lang.items()
            for sent in sents
        ]

        # Shuffle to mix languages with fixed seed for reproducibility
        random.seed(42)  # Fixed seed for consistent results
        random.shuffle(all_sentences)

        # Return requested number of sentences (just the text, not the language tag)
        if len(all_sentences) < num_sentences:
            print(f"Warning: Only found {len(all_sentences)} suitable sentences")
            return [sent for sent, _ in all_sentences]

        selected = all_sentences[:num_sentences]
        if SENDER_VERBOSE:
            # Print language distribution of selected sentences
            lang_counts = {}
            for _, lang in selected:
                lang_counts[lang] = lang_counts.get(lang, 0) + 1
            print(f"Selected {num_sentences} sentences - language distribution: {lang_counts}")

        return [sent for sent, _ in selected]

    except ImportError:
        print("Error: 'datasets' library not installed. Install with: pip install datasets")
        sys.exit(1)
    except Exception as e:
        print(f"Error loading multilingual dataset: {e}")
        sys.exit(1)

def get_openwebtext_sentences(num_sentences):
    """Load sentences from OpenWebText dataset.

    Sentences are read from the pickle cache ``.openwebtext_cache.pkl`` next
    to this file when it exists and is readable. Otherwise up to 10000
    examples are streamed from ``openwebtext``, split on ``'. '`` (newlines
    flattened to spaces first), and pieces of 10-50 words that are longer
    than 20 characters are kept with ``.`` appended and written back to the
    cache. Examples without a non-empty string ``text`` field are skipped.

    Sentences are returned in extraction order (no shuffling).

    Args:
        num_sentences: Maximum number of sentences to return.

    Returns:
        A list of at most ``num_sentences`` sentence strings. If fewer are
        available a warning is printed and all of them are returned.

    Exits the process with status 1 if the ``datasets`` package is missing
    or any other error occurs while loading.
    """
    try:
        from datasets import load_dataset
        import itertools
        import pickle

        # Cache file path
        cache_dir = os.path.dirname(os.path.abspath(__file__))
        cache_file = os.path.join(cache_dir, ".openwebtext_cache.pkl")

        sentences = []

        # Try to load from cache first
        if os.path.exists(cache_file):
            try:
                # NOTE(security): pickle.load executes arbitrary code from the
                # file. The cache is written by this script only; do not point
                # it at files from an untrusted source.
                with open(cache_file, 'rb') as f:
                    cached_data = pickle.load(f)
                sentences = cached_data['sentences']
                if SENDER_VERBOSE:
                    print(f"Loaded {len(sentences)} cached sentences from OpenWebText")
            except Exception as e:
                print(f"Cache read failed, will reload: {e}")
                sentences = []

        # If cache doesn't exist or failed to load, fetch from dataset
        if not sentences:
            print("Loading OpenWebText dataset (this may take a moment)...")
            # Use streaming to avoid loading the entire dataset
            ds = load_dataset("openwebtext", split="train", streaming=True, trust_remote_code=True)

            # Take only what we need (first max_examples examples)
            max_examples = 10000
            examples = list(itertools.islice(ds, max_examples))
            print(f"Loaded {len(examples)} examples from OpenWebText (streaming mode, max {max_examples})")

            # Extract sentences from the texts
            for item in examples:
                # Skip rows without usable text instead of aborting the run
                if not item or "text" not in item:
                    continue
                text = item["text"]
                if not text or not isinstance(text, str):
                    continue

                # Split on period followed by space (newlines flattened first)
                for sent in text.replace('\n', ' ').split('. '):
                    sent = sent.strip()
                    # Filter for reasonable length sentences
                    if 10 <= len(sent.split()) <= 50 and len(sent) > 20:
                        sentences.append(sent + ".")

            # Cache the extracted sentences
            print(f"Extracted {len(sentences)} sentences, saving to cache...")
            try:
                with open(cache_file, 'wb') as f:
                    pickle.dump({'sentences': sentences}, f)
                print(f"Cache saved to {cache_file}")
            except Exception as e:
                print(f"Warning: Failed to save cache: {e}")

        # Return requested number of sentences
        if len(sentences) < num_sentences:
            print(f"Warning: Only found {len(sentences)} suitable sentences")
            return sentences

        return sentences[:num_sentences]

    except ImportError:
        print("Error: 'datasets' library not installed. Install with: pip install datasets")
        sys.exit(1)
    except Exception as e:
        print(f"Error loading OpenWebText: {e}")
        sys.exit(1)

def _project_root() -> str:
    """Return the absolute path three directory levels above this file's directory.

    With this file located in tests/python/send_packets, that path is the
    repository root.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    three_levels_up = os.path.join(script_dir, *([os.pardir] * 3))
    return os.path.abspath(three_levels_up)


def _load_text_pool(mode: str, count: int) -> List[str]:
    """Return up to ``count`` source texts for the given dataset ``mode``.

    ``'code'``, ``'multilingual'`` and ``'openwebtext'`` are delegated to
    their dataset loaders. Any other mode uses the built-in simple
    sentences, cycling through them; once the list has been used up,
    repeated sentences are prefixed with ``[Packet N]`` (N is 1-based).
    """
    dataset_loaders = (
        ('code', get_code_snippets),
        ('multilingual', get_multilingual_sentences),
        ('openwebtext', get_openwebtext_sentences),
    )
    for dataset_name, loader in dataset_loaders:
        if mode == dataset_name:
            return loader(count)

    # simple
    templates = get_simple_sentences()
    n_templates = len(templates)
    return [
        templates[pos % n_templates] if pos < n_templates
        else f"[Packet {pos+1}] {templates[pos % n_templates]}"
        for pos in range(count)
    ]


def _chunk_and_send(sock: socket.socket, payload_utf8: bytes, port: int, delay_ms: int = 0,
                    use_msg_id_header: bool = False, message_id: int = 0,
                    dup_chunks: int = 0, host: str = "127.0.0.1"):
    """Split a payload into chunks of at most 1200 bytes and send each as a UDP datagram.

    Every datagram starts with a header of 4-byte big-endian fields:

    * basic header: ``seq`` + ``total``
    * extended header (``use_msg_id_header=True``):
      ``b'TOKN'`` + ``seq`` + ``total`` + ``message_id``

    where ``seq`` is the 0-based chunk index and ``total`` the chunk count.
    An empty payload is sent as one datagram with an empty chunk
    (``seq=0``, ``total=1``).

    Args:
        sock: Object providing ``sendto(data, address)``.
        payload_utf8: Bytes to send.
        port: Destination UDP port.
        delay_ms: Pause after each chunk, in milliseconds. Zero, negative
            or None means no pause.
        use_msg_id_header: Use the extended header with the message id.
        message_id: Message id placed in the extended header.
        dup_chunks: Number of extra copies of each datagram to send;
            negative values are treated as 0.
        host: Destination address. Defaults to the loopback address.
    """
    MAX_CHUNK = 1200  # must match server limit
    # Ceiling division, with at least one (possibly empty) chunk.
    total = max(1, (len(payload_utf8) + MAX_CHUNK - 1) // MAX_CHUNK)
    destination = (host, port)
    # One primary copy plus the requested duplicates (to reduce drop risk).
    copies = 1 + max(0, dup_chunks)
    pause_s = delay_ms / 1000.0 if delay_ms and delay_ms > 0 else 0.0

    for seq in range(total):
        chunk = payload_utf8[seq * MAX_CHUNK:(seq + 1) * MAX_CHUNK]
        header = seq.to_bytes(4, 'big') + total.to_bytes(4, 'big')
        if use_msg_id_header:
            header = b'TOKN' + header + message_id.to_bytes(4, 'big')
        datagram = header + chunk
        for _ in range(copies):
            sock.sendto(datagram, destination)
        if pause_s:
            time.sleep(pause_s)


def _tokenizer_load_gpt2():
    """Load the GPT-2 tokenizer from the local ``tokenizer.json`` (no network).

    The file is expected at ``<project root>/tokenizer_data/gpt2/tokenizer.json``.

    Returns:
        A ``tokenizers.Tokenizer`` with padding disabled.

    Exits the process with status 1 if the ``tokenizers`` package cannot be
    imported or the tokenizer file does not exist.
    """
    # Use Rust-backed tokenizers via local tokenizer.json (no network)
    try:
        from tokenizers import Tokenizer
    except ImportError:
        print("Error: tokenizers (Rust) package is required. Install with: pip install tokenizers")
        sys.exit(1)

    tk_path = os.path.join(_project_root(), 'tokenizer_data', 'gpt2', 'tokenizer.json')
    if not os.path.exists(tk_path):
        print(f"Error: tokenizer config not found at {tk_path}")
        sys.exit(1)
    tok = Tokenizer.from_file(tk_path)
    # Make sure no pad tokens are ever added to an encoding, whatever the
    # tokenizer.json configures. no_padding() is the API for this;
    # enable_padding(False) is not a way to turn padding off.
    tok.no_padding()
    return tok


def _yield_packet_texts_by_tokens(num_packets: int,
                                  tokens_per_packet: int,
                                  mode: str,
                                  lowercase: bool,
                                  max_seq_len: int,
                                  delay_ms: int = 0,
                                  dataset_offset: int = 0) -> Iterable[Tuple[int, str]]:
    """Yield ``(idx, text)`` for each packet, ``idx`` counting from 1.

    Source texts are tokenized with the local GPT-2 tokenizer (Rust
    ``tokenizers``) and their IDs concatenated into one stream; each packet
    is the decoded text of the next ``tokens_per_packet`` IDs.

    If the decoded text does not re-encode to the same IDs, the packet is
    rebuilt from the first ``tokens_per_packet`` IDs of the re-encoding,
    topped up from the stream when the re-encoding is shorter. The rebuilt
    text is not checked again, and it can hold fewer tokens than requested
    when the buffered stream is empty at that point.

    Args:
        num_packets: Number of packets to yield.
        tokens_per_packet: Target token count per packet; must be positive.
        mode: Dataset mode passed to ``_load_text_pool``.
        lowercase: Lowercase every text before encoding.
        max_seq_len: Upper bound allowed for ``tokens_per_packet``.
        delay_ms: Unused here; kept so existing callers keep working.
        dataset_offset: Index into the text pool (modulo its size) where
            reading starts, and extra pool size to request.

    Raises:
        ValueError: ``tokens_per_packet`` is not positive or exceeds
            ``max_seq_len`` (raised when iteration starts).
        RuntimeError: The text pool is empty after a reload.
    """
    if tokens_per_packet <= 0:
        raise ValueError("tokens_per_packet must be positive")
    if tokens_per_packet > max_seq_len:
        raise ValueError(f"Requested tokens_per_packet ({tokens_per_packet}) exceeds max_seq_len ({max_seq_len})")

    tok = _tokenizer_load_gpt2()

    oversample = 20
    # Ensure the pool accounts for the starting offset so each trial can use distinct texts
    pool_size = max((dataset_offset + num_packets) * oversample, 100)
    pool = _load_text_pool(mode, pool_size)
    pool_idx = dataset_offset % len(pool) if pool else 0
    # Token IDs already encoded but not yet assigned to a packet.
    pending: List[int] = []

    def encode_ids(text: str) -> List[int]:
        if lowercase:
            text = text.lower()
        return tok.encode(text, add_special_tokens=False).ids

    def decode_ids(ids: List[int]) -> str:
        return tok.decode(ids)

    for pkt_idx in range(1, num_packets + 1):
        # Buffer source texts until a full packet's worth of IDs is available.
        while len(pending) < tokens_per_packet:
            if pool_idx >= len(pool):
                # Refill pool if exhausted
                pool = _load_text_pool(mode, max(num_packets * oversample, 100))
                pool_idx = 0
                if not pool:
                    raise RuntimeError("Failed to load more source texts for packing")
            pending.extend(encode_ids(pool[pool_idx]))
            pool_idx += 1

        take = pending[:tokens_per_packet]
        del pending[:tokens_per_packet]

        # Roundtrip check for sanity; the re-encoding is computed once and
        # reused by the repair path below.
        text = decode_ids(take)
        roundtrip = encode_ids(text)
        if roundtrip != take:
            # In rare cases, decoding normalization may change bytes; fix by re-splitting
            shortfall = tokens_per_packet - len(roundtrip)
            if shortfall > 0:
                # Fallback: add more IDs from the stream to meet the target
                roundtrip.extend(pending[:shortfall])
                del pending[:shortfall]
            text = decode_ids(roundtrip[:tokens_per_packet])

        yield pkt_idx, text


def send_test_packets(target_port=6000,
                      num_packets=5,
                      lowercase=False,
                      use_openwebtext=False,
                      use_multilingual=False,
                      use_code=False,
                      tokens_per_packet: int = 0,
                      max_seq_len: int = 2048,
                      delay_ms: int = 0,
                      dataset_offset: int = 0,
                      use_msg_id_header: bool = True,
                      dup_chunks: int = 1,
                      on_send: Optional[Callable[[int, float], None]] = None,
                      packet_delay_ms: int = 0,
                      permit_queue: Optional["queue.Queue"] = None,
                      verbose: bool = True):
    """Send specified number of test packets.

    If tokens_per_packet > 0, builds each packet by token count using Rust tokenizers (GPT-2).
    Otherwise one source text is sent per packet.
    Uses UDP chunking compatible with the VM server (<=1200 bytes per chunk).

    Args:
        target_port: Destination UDP port.
        num_packets: Number of packets (messages) to send. In text mode this
            is reduced if the dataset yields fewer texts.
        lowercase: Lowercase the text before sending / tokenizing.
        use_openwebtext, use_multilingual, use_code: Dataset selectors. If
            several are set, code wins over multilingual, which wins over
            openwebtext; with none set the built-in simple sentences are used.
        tokens_per_packet: Token count per packet; 0 selects text mode.
        max_seq_len: Upper bound allowed for ``tokens_per_packet``.
        delay_ms: Pause after each UDP chunk, in milliseconds.
        dataset_offset: Start offset into the text pool (token mode). The
            first message id is ``1 + dataset_offset`` in both modes.
        use_msg_id_header: Use the extended chunk header with the message id.
        dup_chunks: Extra copies of every chunk to send.
        on_send: Optional callback ``(message_id, time.perf_counter())``
            invoked just before a message is sent.
        packet_delay_ms: Pause after each message, in milliseconds.
        permit_queue: Optional queue; one item is taken (waiting up to 300 s)
            before each message. If the wait fails, the message is sent anyway.
        verbose: Stored in the module-level ``SENDER_VERBOSE`` flag; False
            suppresses the progress prints.
    """
    # Set module-wide verbosity for helper prints
    global SENDER_VERBOSE
    SENDER_VERBOSE = bool(verbose)

    mode = 'simple'
    if use_code:
        mode = 'code'
    elif use_multilingual:
        mode = 'multilingual'
    elif use_openwebtext:
        mode = 'openwebtext'

    if SENDER_VERBOSE:
        print(f"Sending {num_packets} test packets to port {target_port}")
        print(f"Mode: {mode.capitalize()} | lowercase={lowercase}")
        if tokens_per_packet > 0:
            print(f"Token-controlled mode: {tokens_per_packet} tokens/packet (max_seq_len={max_seq_len})")
        if delay_ms:
            print(f"Inter-chunk delay: {delay_ms} ms")
        print()

    first_message_id = 1 + (dataset_offset or 0)

    # The context manager closes the socket on every exit path, including
    # SystemExit raised by the dataset loaders and errors from callbacks.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        if tokens_per_packet > 0:
            # Token-controlled path
            packets = _yield_packet_texts_by_tokens(num_packets, tokens_per_packet, mode, lowercase,
                                                    max_seq_len, delay_ms, dataset_offset)
            sent_note = f" (tokens={tokens_per_packet})"
        else:
            # Text-only path (legacy behavior) with chunking
            sentences = _load_text_pool(mode, num_packets)
            if len(sentences) < num_packets:
                shortfall_notes = {
                    'code': "suitable code snippets",
                    'multilingual': "suitable sentences in multilingual dataset",
                    'openwebtext': "suitable sentences in OpenWebText",
                }
                print(f"Warning: Only found {len(sentences)} {shortfall_notes.get(mode, 'suitable sentences')}")
                num_packets = len(sentences)

            if lowercase:
                sentences = [s.lower() for s in sentences]
                if SENDER_VERBOSE:
                    print("Using lowercase-only text")

            packets = enumerate(sentences, 1)
            sent_note = ""

        for message_id, (i, text) in enumerate(packets, first_message_id):
            # Pacing: wait for permit if provided (sequential send-on-complete)
            if permit_queue is not None:
                try:
                    permit_queue.get(timeout=300)
                except Exception:
                    # Deliberate best effort: if no permit arrives (timeout or
                    # a broken queue) the packet is sent anyway so the run
                    # cannot stall forever.
                    pass

            # Send packet (no automatic retries - receiver will handle timeouts)
            payload = text.encode('utf-8')
            if on_send:
                on_send(message_id, time.perf_counter())
            _chunk_and_send(sock, payload, target_port, delay_ms,
                            use_msg_id_header=use_msg_id_header, message_id=message_id,
                            dup_chunks=dup_chunks)
            if SENDER_VERBOSE:
                preview = text[:80] + "..." if len(text) > 80 else text
                print(f"Sent packet {i}: '{preview}'{sent_note}")

            # Optional fixed delay (disabled by default)
            if packet_delay_ms > 0:
                time.sleep(packet_delay_ms / 1000.0)

    if SENDER_VERBOSE:
        print(f"\nAll {num_packets} packets sent successfully!")

if __name__ == "__main__":
    # Command-line entry point: parse options, validate them up front, then
    # hand everything to send_test_packets().
    parser = argparse.ArgumentParser(
        description='Send N test packets to DPDK tokenizer',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Send 5 simple test packets (default)
  python send_n_packets.py
  
  # Send 20 packets using OpenWebText
  python send_n_packets.py -n 20 --openwebtext
  
  # Send 10 multilingual packets (mixed languages)
  python send_n_packets.py -n 10 --multilingual
  
  # Send 5 code snippets
  python send_n_packets.py -n 5 --code
  
  # Send 100 lowercase packets to port 7000
  python send_n_packets.py 7000 -n 100 --lowercase
        """
    )

    parser.add_argument('port', type=int, nargs='?', default=6000,
                        help='Target port (default: 6000)')
    parser.add_argument('-n', '--num-packets', type=int, default=5,
                        help='Number of packets to send (default: 5)')
    parser.add_argument('--lowercase', action='store_true',
                        help='Convert text to lowercase')
    # Preferred unified dataset selector
    parser.add_argument('--dataset', choices=['simple', 'openwebtext', 'multilingual', 'code'], default=None,
                        help='Dataset to use (simple|openwebtext|multilingual|code). Overrides legacy flags.')
    # Legacy flags (kept for backward-compat)
    parser.add_argument('--openwebtext', action='store_true',
                        help='[Deprecated] Use OpenWebText dataset instead of simple sentences')
    parser.add_argument('--multilingual', action='store_true',
                        help='[Deprecated] Use multilingual MC4 dataset (en, fr, de, es, zh mixed)')
    parser.add_argument('--code', action='store_true',
                        help='[Deprecated] Use GitHub Code dataset for code snippets')
    parser.add_argument('--tokens-per-packet', type=int, default=0,
                        help='If > 0, build each packet to contain exactly this many GPT-2 tokens using Rust tokenizers')
    parser.add_argument('--max-seq-len', type=int, default=2048,
                        help='Maximum tokens supported by the receiver (default: 2048)')
    parser.add_argument('--delay-ms', type=int, default=0,
                        help='Optional inter-chunk delay in milliseconds (helps pacing for large payloads)')
    parser.add_argument('--packet-delay-ms', type=int, default=0,
                        help='Optional inter-packet delay in milliseconds (sleep after each message)')
    parser.add_argument('--no-msg-id-header', action='store_true',
                        help='Disable extended header with message-id (default: enabled)')

    args = parser.parse_args()

    # Validate target port: rejecting it here avoids a traceback from the
    # socket layer after the datasets have already been loaded.
    if not 1 <= args.port <= 65535:
        print("Error: port must be between 1 and 65535")
        sys.exit(1)

    # Validate packet count
    if args.num_packets <= 0:
        print("Error: Number of packets must be positive")
        sys.exit(1)

    # Resolve dataset mode: --dataset wins; otherwise at most one legacy flag
    if args.dataset:
        dataset_mode = args.dataset
    else:
        dataset_options = sum([args.openwebtext, args.multilingual, args.code])
        if dataset_options > 1:
            print("Error: Can only use one dataset option at a time (--openwebtext, --multilingual, or --code)")
            sys.exit(1)
        if args.code:
            dataset_mode = 'code'
        elif args.multilingual:
            dataset_mode = 'multilingual'
        elif args.openwebtext:
            dataset_mode = 'openwebtext'
        else:
            dataset_mode = 'simple'

    # Validate tokens-per-packet (0 selects the plain text mode)
    if args.tokens_per_packet < 0:
        print("Error: --tokens-per-packet must be >= 0")
        sys.exit(1)

    # Validate inter-chunk delay: a negative value cannot be slept for
    if args.delay_ms < 0:
        print("Error: --delay-ms must be >= 0")
        sys.exit(1)

    send_test_packets(
        target_port=args.port,
        num_packets=args.num_packets,
        lowercase=args.lowercase,
        use_openwebtext=(dataset_mode == 'openwebtext'),
        use_multilingual=(dataset_mode == 'multilingual'),
        use_code=(dataset_mode == 'code'),
        tokens_per_packet=args.tokens_per_packet,
        max_seq_len=args.max_seq_len,
        delay_ms=args.delay_ms,
        use_msg_id_header=(not args.no_msg_id_header),
        packet_delay_ms=args.packet_delay_ms,
    )
