"""
Download science_tech shards for the configured BUCKET (see BUCKET below)
directly and parse them manually, bypassing HuggingFace's schema validation.
"""

import io
import json
import sys
from pathlib import Path

import zstandard as zstd
from datasets import Dataset
from huggingface_hub import hf_hub_download
from tqdm import tqdm

DATASET_NAME = "allenai/dolma3_longmino_pool"
SOURCE = "olmocr_science_pdfs"
TOPIC = "science_tech"
BUCKET = "2e16"
OUTPUT_PATH = f"{TOPIC}_{BUCKET}_dataset"
LOCAL_CACHE = Path("./science_tech_shards")

from huggingface_hub import HfFileSystem
from concurrent.futures import ProcessPoolExecutor, as_completed
import os

if __name__ == "__main__":
    # Only the launching process lists the repo and prepares the cache.
    # Under the "spawn"/"forkserver" multiprocessing start methods every
    # worker re-imports this module; the constants above must stay visible to
    # workers, but the network listing / exit-on-error below must not re-run.
    print("="*80)
    print(f"Downloading and Parsing: {SOURCE}-{TOPIC}-{BUCKET}")
    print("="*80)

    fs = HfFileSystem()

    # Get list of files to download
    print("\nListing shard files...")
    repo_path = f"datasets/{DATASET_NAME}/data/{SOURCE}-{TOPIC}-{BUCKET}"
    try:
        entries = fs.ls(repo_path, detail=False)
    except Exception as e:
        print(f"Error listing files: {e}")
        # sys.exit rather than the site-injected `exit`, which is absent
        # under `python -S` and in some embedded/frozen interpreters.
        sys.exit(1)

    # Keep only the file name; sort so the submission order is reproducible.
    shard_files = sorted(p.split('/')[-1] for p in entries if p.endswith('.jsonl.zst'))
    print(f"Found {len(shard_files)} shard files")
    if not shard_files:
        # An empty listing almost always means a wrong SOURCE/TOPIC/BUCKET;
        # fail loudly instead of saving an empty dataset.
        print(f"No .jsonl.zst shard files found under {repo_path}")
        sys.exit(1)

    LOCAL_CACHE.mkdir(exist_ok=True)

def process_shard(shard_file):
    """Download and parse a single shard. Returns (documents, errors, lines).

    The shard ``data/{SOURCE}-{TOPIC}-{BUCKET}/{shard_file}`` is fetched from
    the DATASET_NAME dataset repo into LOCAL_CACHE, zstd-decompressed and
    UTF-8-decoded as a stream (the decompressed shard is never materialised
    as one string), and parsed one JSON document per line.

    Args:
        shard_file: File name of the shard inside the bucket directory.

    Returns:
        ``(documents, errors, lines)``:
        ``documents`` is a list of ``{'text': ...}`` dicts, one per line whose
        JSON is an object with a non-empty string ``text`` field;
        ``errors`` counts non-blank lines that were invalid JSON or lacked
        such a field; ``lines`` counts non-blank lines seen.
        If the shard cannot be downloaded or read at all, the error is
        printed and ``([], 1, 0)`` is returned so one bad shard does not
        abort the whole run.
    """
    try:
        # Download shard
        local_path = hf_hub_download(
            repo_id=DATASET_NAME,
            filename=f"data/{SOURCE}-{TOPIC}-{BUCKET}/{shard_file}",
            repo_type="dataset",
            local_dir=LOCAL_CACHE,
            local_dir_use_symlinks=False
        )

        documents = []
        errors = 0
        lines = 0

        with open(local_path, 'rb') as f_in:
            dctx = zstd.ZstdDecompressor()
            # read_across_frames: files written as several concatenated zstd
            # frames would otherwise be silently truncated after frame one.
            reader = dctx.stream_reader(f_in, read_across_frames=True)
            # newline="\n" splits exactly where str.split('\n') did and leaves
            # any '\r' in the line, where json.loads treats it as whitespace.
            with io.TextIOWrapper(reader, encoding='utf-8', newline='\n') as text_stream:
                for line in text_stream:
                    if not line.strip():
                        continue

                    lines += 1

                    try:
                        doc = json.loads(line)
                    except json.JSONDecodeError:
                        errors += 1
                        continue

                    # Valid JSON need not be an object: indexing a str/list/
                    # number raised TypeError and used to discard the whole
                    # shard. Require a non-empty string so the resulting
                    # 'text' column has a single consistent type.
                    text = doc.get('text') if isinstance(doc, dict) else None
                    if isinstance(text, str) and text:
                        documents.append({'text': text})
                    else:
                        errors += 1

        return documents, errors, lines

    except Exception as e:
        # Best-effort per shard, but never silently: name the shard and cause.
        print(f"  Failed to process shard {shard_file}: {e}")
        return [], 1, 0

# Process in parallel
if __name__ == "__main__":
    # Everything below spawns workers or writes output, so it must only run in
    # the launching process: with the "spawn"/"forkserver" start methods each
    # worker re-imports this module, and creating a ProcessPoolExecutor during
    # that bootstrap import raises RuntimeError.
    CHARS_PER_TOKEN = 4  # rough heuristic for the token estimate below
    PROGRESS_EVERY = 50_000  # print a line each time this many more docs load

    # os.cpu_count() may return None, and `// 2` is 0 on a single-core box,
    # which ProcessPoolExecutor rejects; also never start more workers than shards.
    num_workers = max(1, min((os.cpu_count() or 2) // 2, len(shard_files)))
    print(f"\nProcessing {len(shard_files)} shards with {num_workers} parallel workers...")

    all_documents = []
    total_errors = 0
    total_lines = 0
    next_report = PROGRESS_EVERY

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Submit all shards; keep the future -> shard mapping for error messages.
        futures = {executor.submit(process_shard, shard): shard for shard in shard_files}

        # Collect results as they complete
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing shards"):
            try:
                documents, errors, lines = future.result()
            except Exception as e:
                # e.g. a crashed worker (BrokenProcessPool); tqdm.write keeps
                # the progress bar intact.
                tqdm.write(f"  Error in {futures[future]}: {e}")
                total_errors += 1
                continue

            all_documents.extend(documents)
            total_errors += errors
            total_lines += lines

            # Report when a PROGRESS_EVERY boundary is crossed; an exact
            # `% PROGRESS_EVERY == 0` test almost never fires because each
            # shard adds a large batch at once (and it fired on 0).
            if len(all_documents) >= next_report:
                tqdm.write(f"  Loaded {len(all_documents):,} documents so far...")
                next_report = (len(all_documents) // PROGRESS_EVERY + 1) * PROGRESS_EVERY

    print(f"\n✓ Processing complete!")
    print(f"  Total lines processed: {total_lines:,}")
    print(f"  Successfully loaded: {len(all_documents):,} documents")
    print(f"  Errors/skipped: {total_errors:,}")

    # Exact character count over every document. Results arrive in shard
    # completion order, so a "first N documents" sample is biased toward the
    # shards that happened to finish first.
    total_chars = sum(len(doc['text']) for doc in all_documents)

    # Convert to HF dataset
    print(f"\nConverting to HuggingFace dataset...")
    ds = Dataset.from_list(all_documents)
    del all_documents  # drop the Python-list copy before save_to_disk

    print(f"  Total examples: {len(ds):,}")

    # Show stats
    if len(ds) > 0:
        example = ds[0]
        print(f"\n  First example:")
        print(f"    Text length: {len(example['text']):,} chars")
        print(f"    Preview: {example['text'][:200]}...")

        avg_chars = total_chars / len(ds)
        avg_tokens_per_doc = avg_chars / CHARS_PER_TOKEN
        total_estimated_tokens = total_chars / CHARS_PER_TOKEN / 1e9

        print(f"\n  Estimating total tokens (~{CHARS_PER_TOKEN} chars/token)...")
        print(f"    Total chars: {total_chars:,}")
        print(f"    Average chars per doc: {avg_chars:,.0f}")
        print(f"    Average tokens per doc: {avg_tokens_per_doc:,.0f}")
        print(f"    Estimated total tokens: {total_estimated_tokens:.1f}B")

    # Save
    print(f"\nSaving dataset to: {OUTPUT_PATH}")
    ds.save_to_disk(OUTPUT_PATH)
    print(f"✓ Saved!")

    print(f"\n{'='*80}")
    print("Done! Load with:")
    print(f"  from datasets import load_from_disk")
    print(f"  ds = load_from_disk('{OUTPUT_PATH}')")
    print("="*80)
