"""
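Verify that EOS token positions match the saved sequence offsets.
Samples up to 8 random 1M-token chunks from each of buckets 5 and 6 and checks
that the token just before every sequence end offset is the EOS token.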
"""

import numpy as np
from pathlib import Path
import random

# Configuration
VOCAB: str = "32k"  # vocabulary tag; also selects the data directory name below
OUTPUT_DIR: Path = Path("pretokenized_data_" + VOCAB)
BUCKETS_TO_CHECK: list = [5, 6]  # bucket numbers whose files are verified
EOS_TOKEN_ID: int = 0  # <|endoftext|>
CHUNK_SIZE: int = 10**6  # 1M tokens per sampled chunk
NUM_CHUNKS: int = 8  # maximum number of chunks sampled per bucket


def _select_chunk_starts(total_tokens: int):
    """
    Return up to NUM_CHUNKS sorted, CHUNK_SIZE-aligned chunk start positions.

    Every returned start satisfies ``start + CHUNK_SIZE <= total_tokens``, so the
    whole chunk lies inside the token file. Returns an empty list when no full
    chunk fits.
    """
    num_full_chunks = total_tokens // CHUNK_SIZE
    candidates = range(0, num_full_chunks * CHUNK_SIZE, CHUNK_SIZE)
    return sorted(random.sample(candidates, min(NUM_CHUNKS, num_full_chunks)))


def _check_chunk(tokens, offsets, chunk_start: int, chunk_end: int):
    """
    Check the EOS token of every sequence that ends inside one chunk.

    Sequence ``j - 1`` spans ``[offsets[j-1], offsets[j])``, so its EOS is
    expected at ``offsets[j] - 1``. That position lies in
    ``[chunk_start, chunk_end)`` exactly when
    ``chunk_start < offsets[j] <= chunk_end``. This includes the sequence that
    starts before the chunk but ends inside it.

    Prints a per-chunk report, including details for every mismatch.

    Returns:
        (checked, mismatches) counts for this chunk only.
    """
    # Boolean scan over all offsets (does not rely on offsets being sorted).
    # Index 0 is excluded: offsets[0] opens the first sequence and ends none.
    end_indices = np.flatnonzero((offsets > chunk_start) & (offsets <= chunk_end))
    end_indices = end_indices[end_indices > 0]

    if end_indices.size == 0:
        print("    No sequence boundaries in this chunk")
        return 0, 0

    print(f"    Found {end_indices.size} sequence boundaries")

    # Cast to signed ints before subtracting. Every selected offset is
    # > chunk_start >= 0, so the result is >= 0 and < chunk_end <= len(tokens).
    eos_positions = offsets[end_indices].astype(np.int64) - 1
    actual_tokens = np.asarray(tokens[eos_positions])
    bad = np.flatnonzero(actual_tokens != EOS_TOKEN_ID)

    for k in bad:
        seq_idx = int(end_indices[k]) - 1
        print(f"      MISMATCH at sequence {seq_idx}:")
        print(f"        Expected EOS at position {int(eos_positions[k]):,}")
        print(f"        Found token ID: {int(actual_tokens[k])} (expected {EOS_TOKEN_ID})")
        print(f"        Sequence range: [{int(offsets[seq_idx]):,} - {int(offsets[seq_idx + 1]):,})")

    checked = int(end_indices.size)
    mismatches = int(bad.size)
    if mismatches == 0:
        print(f"    ✓ All {checked} boundaries verified correctly")
    else:
        print(f"    ✗ Found {mismatches} mismatches in this chunk")
    return checked, mismatches


def verify_bucket_offsets(bucket_id: int):
    """
    Verify a bucket's offsets by checking the token just before each sequence end.

    Loads ``bucket_{id:02d}_train.bin`` (read as uint16 token IDs) and
    ``bucket_{id:02d}_train_offsets.bin`` (read as uint64 boundaries) from
    OUTPUT_DIR. It then samples up to NUM_CHUNKS CHUNK_SIZE-aligned chunks.
    For every offset ``offsets[j]`` (j >= 1) that ends a sequence inside a
    sampled chunk, it checks that ``tokens[offsets[j] - 1] == EOS_TOKEN_ID``.

    Args:
        bucket_id: Bucket number used to build the file names.

    Returns:
        (total_checked, mismatches) tuple. Both are 0 when the files are
        missing or empty, or when they hold fewer than CHUNK_SIZE tokens.
    """
    print(f"\n{'='*80}")
    print(f"Verifying Bucket {bucket_id}")
    print(f"{'='*80}")

    # Load files
    tokens_path = OUTPUT_DIR / f"bucket_{bucket_id:02d}_train.bin"
    offsets_path = OUTPUT_DIR / f"bucket_{bucket_id:02d}_train_offsets.bin"

    if not tokens_path.exists() or not offsets_path.exists():
        print(f"ERROR: Files not found for bucket {bucket_id}")
        return 0, 0

    # np.memmap raises on a zero-length file; report it so the remaining
    # buckets are still verified.
    if tokens_path.stat().st_size == 0 or offsets_path.stat().st_size == 0:
        print(f"ERROR: Empty token or offset file for bucket {bucket_id}")
        return 0, 0

    print("Loading tokens and offsets...")
    tokens = np.memmap(tokens_path, dtype=np.uint16, mode='r')
    offsets = np.memmap(offsets_path, dtype=np.uint64, mode='r')

    total_tokens = len(tokens)
    num_sequences = len(offsets) - 1

    print(f"  Total tokens: {total_tokens:,}")
    print(f"  Total sequences: {num_sequences:,}")
    print(f"  Offsets array size: {len(offsets):,}")

    # A bucket holding exactly CHUNK_SIZE tokens still has one full chunk.
    if total_tokens < CHUNK_SIZE:
        print(f"ERROR: Not enough tokens for {CHUNK_SIZE:,} token chunks")
        return 0, 0

    chunk_starts = _select_chunk_starts(total_tokens)
    print(f"\nChecking {len(chunk_starts)} random chunks of {CHUNK_SIZE:,} tokens each...")

    total_checked = 0
    total_mismatches = 0
    for chunk_number, chunk_start in enumerate(chunk_starts, start=1):
        chunk_end = chunk_start + CHUNK_SIZE
        print(f"\n  Chunk {chunk_number}: tokens [{chunk_start:,} - {chunk_end:,})")
        checked, mismatches = _check_chunk(tokens, offsets, chunk_start, chunk_end)
        total_checked += checked
        total_mismatches += mismatches

    return total_checked, total_mismatches


def main():
    """
    Run the EOS/offset cross-check on every bucket in BUCKETS_TO_CHECK.

    Prints a per-bucket summary table. A bucket in which no boundary was
    checked (verify_bucket_offsets returned 0 checks, e.g. missing files or
    too few tokens) is reported as NO DATA instead of PASS. This prevents a
    run that verified nothing from claiming success.

    Returns:
        0 if every bucket had at least one boundary checked and no mismatches
        were found, otherwise 1 (usable as a process exit status).
    """
    print("="*80)
    print("Verify Offsets: Cross-check EOS token positions")
    print("="*80)
    print(f"Checking {NUM_CHUNKS} random {CHUNK_SIZE:,}-token chunks per bucket")
    print(f"EOS token ID: {EOS_TOKEN_ID}")

    all_results = []
    for bucket_id in BUCKETS_TO_CHECK:
        checked, mismatches = verify_bucket_offsets(bucket_id)
        all_results.append({
            'bucket': bucket_id,
            'checked': checked,
            'mismatches': mismatches,
        })

    # Summary
    print(f"\n{'='*80}")
    print("Summary:")
    print(f"{'='*80}")
    print(f"{'Bucket':<10} {'Checked':>12} {'Mismatches':>12} {'Status':>15}")
    print(f"{'-'*80}")

    total_checked = 0
    total_mismatches = 0
    unverified_buckets = []

    for result in all_results:
        if result['mismatches']:
            status = "✗ FAIL"
        elif result['checked'] == 0:
            # Nothing was compared, so this bucket has not actually been verified.
            status = "? NO DATA"
            unverified_buckets.append(result['bucket'])
        else:
            status = "✓ PASS"
        print(f"{result['bucket']:<10} {result['checked']:>12,} {result['mismatches']:>12,} {status:>15}")
        total_checked += result['checked']
        total_mismatches += result['mismatches']

    # An empty run (no buckets configured or nothing checked) is incomplete, not a pass.
    incomplete = bool(unverified_buckets) or total_checked == 0

    print(f"{'-'*80}")
    if total_mismatches:
        overall_status = "✗ FAILURES DETECTED"
    elif incomplete:
        overall_status = "? INCOMPLETE"
    else:
        overall_status = "✓ ALL PASS"
    print(f"{'TOTAL':<10} {total_checked:>12,} {total_mismatches:>12,} {overall_status:>15}")

    print(f"\n{'='*80}")
    if total_mismatches:
        print(f"WARNING: Found {total_mismatches} mismatches across {total_checked} checks")
        print("This indicates potential data corruption or offset calculation errors")
    elif incomplete:
        bucket_list = ", ".join(str(b) for b in unverified_buckets) or "(none configured)"
        print(f"WARNING: No boundaries were checked for bucket(s): {bucket_list}")
        print("Check that the data files exist and contain at least one full chunk")
    else:
        print("SUCCESS: All offset boundaries match EOS token positions!")
    print(f"{'='*80}")

    return 0 if not total_mismatches and not incomplete else 1


if __name__ == "__main__":
    random.seed(42)  # Fixed seed so every run samples the same chunks
    # Use main()'s return value (if any) as the process exit status so that
    # callers such as CI can detect failures; a None return exits with 0.
    raise SystemExit(main())
