"""
Load the science_tech topic, 2e15 bucket (32k-64k tokens), from
allenai/dolma3_longmino_pool and save it to disk as a HuggingFace dataset for easy reuse.
"""

import os
import re
import sys
import traceback

from datasets import load_dataset

DATASET_NAME = "allenai/dolma3_longmino_pool"
TOPIC_PATTERN = "science_tech"  # glob fragment matched against shard directory names
BUCKET = "2e15"  # 32k-64k tokens (2^15 - 2^16)

# Derived from the selectors above so that changing TOPIC_PATTERN/BUCKET also
# changes the output directory instead of overwriting the science_tech one.
OUTPUT_PATH = f"{TOPIC_PATTERN}_{BUCKET}_dataset"

# Script body: download the matching shards, preview the first record,
# estimate the total token count from a sample, and save the result to
# OUTPUT_PATH. Exits with status 1 if the download/parse step fails.

print("=" * 80)
print(f"Loading {TOPIC_PATTERN} - {BUCKET}")
print("=" * 80)

# Load with a data_files glob so only the matching shards are downloaded.
# Pattern: data/*science_tech*-2e15/*.jsonl.zst
data_files_pattern = f"data/*{TOPIC_PATTERN}*-{BUCKET}/*.jsonl.zst"

print(f"\nFile pattern: {data_files_pattern}")
print("Loading dataset (this may take a while)...")

# Use half the cores to parallelize download and processing. os.cpu_count()
# may return None, and on a single-core machine `// 2` would give 0, which is
# not a usable worker count -- so clamp to at least one worker.
num_workers = max(1, (os.cpu_count() or 1) // 2)
print(f"Using {num_workers} parallel workers for download...")

try:
    ds = load_dataset(
        DATASET_NAME,
        data_files=data_files_pattern,
        split="train",
        num_proc=num_workers,
    )
except Exception as e:
    # Top-level boundary for the download/parse step only: report the error,
    # print hints, and exit non-zero so shell scripts/schedulers see the failure.
    print(f"\n✗ Error loading dataset: {e}")
    traceback.print_exc()

    print("\n" + "=" * 80)
    print("Troubleshooting:")
    print("=" * 80)
    print("If you get a JSON parse error, you can try:")
    print("1. Load with streaming=True to process incrementally")
    print('2. Use verification_mode="no_checks" (ignore_verifications=True was removed in datasets 3.0)')
    print("3. Download shards directly and parse manually")
    sys.exit(1)

print("\n✓ Dataset loaded successfully!")
print(f"  Total examples: {len(ds):,}")

# Show dataset info
print("\nDataset features:")
print(f"  {ds.features}")

if len(ds) == 0:
    # Nothing to preview, and the average below would divide by zero.
    print("\n⚠ Dataset is empty; skipping preview and token estimate.")
else:
    # Show first example, truncating long string fields.
    print("\nFirst example:")
    preview_chars = 200
    example = ds[0]
    for key, value in example.items():
        if isinstance(value, str):
            # Only mark the preview as truncated when it actually was.
            ellipsis = "..." if len(value) > preview_chars else ""
            print(f"  {key}: {value[:preview_chars]}{ellipsis}")
            print(f"    (length: {len(value):,} chars)")
        else:
            print(f"  {key}: {value}")

    # Estimate tokens from character counts when there is a 'text' field.
    if "text" in example:
        print("\nSampling to estimate total tokens...")
        sample_size = min(1000, len(ds))
        # Take evenly spaced rows across the whole dataset rather than the
        # first `sample_size` rows, so the estimate is not dominated by the
        # first shard. step >= 1 because sample_size <= len(ds).
        step = len(ds) // sample_size
        sample_texts = ds.select(range(0, step * sample_size, step))["text"]
        avg_chars = sum(len(text) for text in sample_texts) / sample_size

        # Rough estimate: ~4 chars per token
        chars_per_token = 4
        avg_tokens_per_doc = avg_chars / chars_per_token
        estimated_tokens_billions = len(ds) * avg_tokens_per_doc / 1e9

        print(f"  Sampled {sample_size:,} documents")
        print(f"  Average chars per doc: {avg_chars:,.0f}")
        print(f"  Average tokens per doc: {avg_tokens_per_doc:,.0f}")
        print(f"  Estimated total tokens: {estimated_tokens_billions:.1f}B")

# Save to disk
print(f"\nSaving dataset to: {OUTPUT_PATH}")
ds.save_to_disk(OUTPUT_PATH)
print("✓ Saved!")

print(f"\n{'=' * 80}")
print("Done! Load with:")
print("  from datasets import load_from_disk")
print(f"  ds = load_from_disk('{OUTPUT_PATH}')")
print("=" * 80)
