"""
Load science_tech 2e15 bucket, handling schema inconsistencies gracefully.
Uses streaming mode to skip problematic shards.
"""

from datasets import load_dataset
from pathlib import Path
import json
from tqdm import tqdm

# Dataset identifier passed to load_dataset(), plus the topic/bucket slice
# that selects which shard directories are read.
DATASET_NAME = "allenai/dolma3_longmino_pool"
TOPIC_PATTERN = "science_tech"
BUCKET = "2e15"
# Directory name the final dataset is written to by save_to_disk().
OUTPUT_PATH = "science_tech_" + BUCKET + "_dataset"

# Start-up banner: rule, title, rule.
banner_rule = "=" * 80
print(
    banner_rule,
    f"Loading {TOPIC_PATTERN} - {BUCKET} (Robust Mode)",
    banner_rule,
    sep="\n",
)

# Glob for the compressed JSONL shards of the chosen topic and bucket.
data_files_pattern = "/".join(
    ["data", f"*{TOPIC_PATTERN}*-{BUCKET}", "*.jsonl.zst"]
)
print(f"File pattern: {data_files_pattern}")

# Open the dataset with streaming=True so documents are pulled while the
# stream is iterated rather than loaded up front.
print("\nLoading dataset in streaming mode...")
stream_options = {
    "data_files": data_files_pattern,
    "split": "train",
    "streaming": True,
}
ds_stream = load_dataset(DATASET_NAME, **stream_options)

# Collect the text of every readable document from the stream.
#
# Failures while reading the stream (e.g. an unreadable shard) are raised by
# the iterator itself, i.e. while fetching the *next* document, not while
# processing one. A try/except placed inside a plain `for` body never sees
# them, so the iterator is advanced explicitly under try/except instead.
print("Streaming and collecting documents (skipping errors)...")
all_examples = []
errors = 0
stream_failed = False  # True when the stream raised before being exhausted

doc_iter = iter(tqdm(ds_stream, desc="Loading", unit=" docs"))
i = -1  # index of the most recent document pulled from the stream
while True:
    try:
        example = next(doc_iter)
    except StopIteration:
        break
    except Exception as e:
        # An iterator that has raised cannot be relied on to resume (a
        # generator is finished once it raises), so stop here and keep what
        # was collected instead of crashing and losing all of it.
        errors += 1
        stream_failed = True
        print(f"  Stream error after {i + 1:,} documents: {e}")
        break
    i += 1

    # Keep only documents that carry a real string in 'text'. A missing,
    # None or non-string value would break the length statistics and the
    # conversion to a single string column later on.
    try:
        text = example['text']
    except (KeyError, TypeError):
        text = None
    if not isinstance(text, str):
        errors += 1
        if errors <= 10:  # Only print first few errors
            print(f"  Skipping document {i}: missing or non-string 'text' field")
        continue

    # Only keep the text field (discard problematic metadata)
    all_examples.append({'text': text})

    # Periodic status
    if (i + 1) % 10000 == 0:
        print(f"  Loaded {len(all_examples):,} documents, {errors} errors")

if stream_failed:
    # Do not claim success: everything after the failure point is missing.
    print(f"\n✗ Streaming stopped early after a stream error - dataset is PARTIAL")
else:
    print(f"\n✓ Streaming complete!")
print(f"  Successfully loaded: {len(all_examples):,} documents")
print(f"  Errors/skipped: {errors:,}")

# Turn the collected {'text': ...} rows into a HuggingFace Dataset
from datasets import Dataset

print(f"\nConverting to HuggingFace dataset...")
ds = Dataset.from_list(all_examples)

num_converted = len(ds)
print(f"  Total examples: {num_converted:,}")

# Report the schema, then (for a non-empty dataset) a peek at the first
# document and a rough token estimate.
print(f"\nDataset info:")
print(f"  Features: {ds.features}")

num_docs = len(ds)
if num_docs > 0:
    first_text = ds[0]['text']
    print(f"\n  First example:")
    print(f"    Text length: {len(first_text):,} chars")
    print(f"    Preview: {first_text[:200]}...")

    # Token estimate: average the character count over at most the first
    # 1000 documents, convert to tokens assuming 4 characters per token,
    # then scale to the whole dataset (reported in billions).
    print(f"\n  Estimating total tokens...")
    chars_per_token = 4
    sample_size = min(1000, num_docs)
    sample_docs = ds.select(range(sample_size))

    char_count = 0
    for doc in sample_docs:
        char_count += len(doc['text'])

    avg_chars = char_count / sample_size
    avg_tokens_per_doc = avg_chars / chars_per_token
    total_estimated_tokens = (num_docs * avg_tokens_per_doc) / 1e9

    print(f"    Sampled {sample_size:,} documents")
    print(f"    Average chars per doc: {avg_chars:,.0f}")
    print(f"    Average tokens per doc: {avg_tokens_per_doc:,.0f}")
    print(f"    Estimated total tokens: {total_estimated_tokens:.1f}B")

# Write the dataset to OUTPUT_PATH
print(f"\nSaving dataset to: {OUTPUT_PATH}")
ds.save_to_disk(OUTPUT_PATH)
print(f"✓ Saved!")

# Closing banner with a copy-pasteable snippet for loading the result
closing_rule = "=" * 80
print(
    "\n" + closing_rule,
    "Done! Load with:",
    "  from datasets import load_from_disk",
    f"  ds = load_from_disk('{OUTPUT_PATH}')",
    closing_rule,
    sep="\n",
)
