"""
Verify tokenization results by decoding the first 100k tokens of one bucket.
Prints the first 100 decoded lines, every line containing a special token,
repo/file/sequence counts, and the structure of the first two complete repos.
"""

from tokenizers import Tokenizer
from pathlib import Path
import numpy as np

# Configuration: vocabulary variant, derived artifact paths, and the bucket to inspect.
VOCAB = "32k"
TOKENIZER_PATH = Path(f"tokenizer_{VOCAB}.json")  # tokenizer JSON for this vocab
OUTPUT_DIR = Path(f"pretokenized_data_{VOCAB}")   # directory holding bucket_XX_train.bin files
BUCKET_TO_VERIFY = 5

# Special tokens whose vocab IDs are reported and whose lines are highlighted.
SPECIAL_TOKENS = [
    "<reponame>",
    "<filename>",
    "<|endoftext|>",
    "<gh_stars>",
    "<fim_prefix>",
    "<fim_middle>",
    "<fim_suffix>",
    "<fim_pad>",
]


def main():
    print("="*80)
    print(f"Verifying Tokenization: Bucket {BUCKET_TO_VERIFY}")
    print("="*80)

    # Load tokenizer
    print(f"\nLoading tokenizer: {TOKENIZER_PATH}")
    tokenizer = Tokenizer.from_file(str(TOKENIZER_PATH))
    print(f"  Vocab size: {tokenizer.get_vocab_size():,}")

    # Check if special tokens are in vocab
    print(f"\n  Checking special tokens in vocab:")
    for token in SPECIAL_TOKENS:
        token_id = tokenizer.token_to_id(token)
        if token_id is not None:
            print(f"    {token:20s} -> ID {token_id}")
        else:
            print(f"    {token:20s} -> NOT FOUND")

    # Show some example token IDs
    test_text = "<reponame>test/repo<filename>test.py\ncode<|endoftext|>"
    test_tokens = tokenizer.encode(test_text).ids
    print(f"\n  Test encoding: '{test_text[:50]}...'")
    print(f"    Token IDs: {test_tokens[:20]}")
    print(f"    Decoded: '{tokenizer.decode(test_tokens, skip_special_tokens=False)[:50]}...'")
    print()

    # Load tokens file
    tokens_path = OUTPUT_DIR / f"bucket_{BUCKET_TO_VERIFY:02d}_train.bin"
    if not tokens_path.exists():
        print(f"\nERROR: Tokens file not found: {tokens_path}")
        return

    print(f"\nLoading tokens: {tokens_path}")
    tokens = np.memmap(tokens_path, dtype=np.uint16, mode='r')
    print(f"  Total tokens: {len(tokens):,}")

    # Read first 100k tokens
    num_tokens = min(100_000, len(tokens))
    first_tokens = tokens[:num_tokens]
    print(f"  Reading first {num_tokens:,} tokens")

    # Show first 50 raw token IDs
    print(f"\n  First 50 token IDs: {first_tokens[:50].tolist()}")

    # Check for special token IDs in the data
    special_token_ids = {tokenizer.token_to_id(t): t for t in SPECIAL_TOKENS if tokenizer.token_to_id(t) is not None}
    print(f"\n  Special token IDs to look for: {special_token_ids}")

    found_special = []
    for i, token_id in enumerate(first_tokens[:1000]):
        if token_id in special_token_ids:
            found_special.append((i, token_id, special_token_ids[token_id]))

    print(f"  Special tokens found in first 1000: {found_special[:10]}")

    # Detokenize
    print(f"\nDetokenizing...")
    text = tokenizer.decode(first_tokens.tolist(), skip_special_tokens=False)
    lines = text.split('\n')
    print(f"  Generated {len(lines):,} lines")

    # Show first 100 lines
    print(f"\n{'='*80}")
    print(f"First 100 Lines:")
    print(f"{'='*80}")
    for i, line in enumerate(lines[:100]):
        # Truncate very long lines for display
        display_line = line if len(line) <= 120 else line[:117] + "..."
        print(f"{i+1:3d}: {display_line}")

    # Show all lines with special tokens (repo structure)
    print(f"\n{'='*80}")
    print(f"Lines with Special Tokens (Repo Structure):")
    print(f"{'='*80}")

    special_lines = []
    for i, line in enumerate(lines):
        for special_token in SPECIAL_TOKENS:
            if special_token in line:
                special_lines.append((i+1, line))
                break

    print(f"Found {len(special_lines)} lines with special tokens\n")

    for line_num, line in special_lines:
        # Truncate very long lines
        display_line = line if len(line) <= 120 else line[:117] + "..."
        print(f"{line_num:4d}: {display_line}")

    # Count repos in this sample
    repo_count = text.count("<reponame>")
    file_count = text.count("<filename>")
    eos_count = text.count("<|endoftext|>")

    print(f"\n{'='*80}")
    print(f"Statistics in first {num_tokens:,} tokens:")
    print(f"{'='*80}")
    print(f"  Repos (<reponame>):       {repo_count:,}")
    print(f"  Files (<filename>):       {file_count:,}")
    print(f"  Sequences (<|endoftext|>): {eos_count:,}")
    print(f"  Avg files per repo:       {file_count/repo_count:.1f}" if repo_count > 0 else "")
    print(f"  Avg tokens per repo:      {num_tokens/repo_count:,.0f}" if repo_count > 0 else "")

    # Show first 2 complete repos
    print(f"\n{'='*80}")
    print(f"First 2 Complete Repos (Structure Only):")
    print(f"{'='*80}")

    repo_texts = text.split("<|endoftext|>")
    for i, repo_text in enumerate(repo_texts[:2]):
        if not repo_text.strip():
            continue

        print(f"\n--- Repo {i+1} ---")

        # Extract and show structure
        repo_lines = repo_text.split('\n')
        for line in repo_lines:
            # Only show lines with special tokens or first few chars of content
            if any(token in line for token in SPECIAL_TOKENS):
                print(f"  {line}")
            elif len(line.strip()) > 0 and repo_lines.index(line) < 5:
                # Show first line of each file content (truncated)
                preview = line[:100] + "..." if len(line) > 100 else line
                print(f"    [content] {preview}")
                break  # Only show first content line per file


if __name__ == "__main__":
    # Script entry point: importing this module does not run the verification.
    main()
