from datasets import load_dataset
from pathlib import Path
import os
import math

# Configuration
# Rough average number of source bytes per token. add_bucket_info() divides
# the summed Python file size by this value to estimate a token count.
BYTES_PER_TOKEN: float = 3.5

def get_token_bucket(tokens):
    """
    Bucket repositories by token count on log2 scale.

    Args:
        tokens: Estimated token count (int or float).

    Returns:
        -1: no Python code (tokens == 0)
         0: < 1k tokens (< 2^10)
         1: 1k-2k (2^10 <= tokens < 2^11)
         2: 2k-4k (2^11 <= tokens < 2^12)
         ...
        10: 512k-1M (2^19 <= tokens < 2^20)
        11: >= 1M tokens (>= 2^20)
    """
    if tokens == 0:
        return -1
    if tokens < 1024:  # < 1k (2^10)
        return 0
    if tokens >= 1_048_576:  # >= 1M (2^20)
        return 11
    # For tokens in [2^10, 2^20), math.frexp returns (m, e) with
    # tokens == m * 2**e and 0.5 <= m < 1, so floor(log2(tokens)) == e - 1
    # exactly. Bucket 1 starts at 2^10, so bucket = (e - 1) - 9 = e - 10.
    # Unlike int(math.log2(tokens)), this cannot round a value just below a
    # power of two up into the next bucket.
    _, exponent = math.frexp(tokens)
    return exponent - 10

def add_bucket_info(repo):
    """Return the estimated Python token count and log2 size bucket of a repo.

    Only files whose ``language`` is ``'Python'`` and that are not flagged
    ``is_vendor`` or ``is_generated`` contribute their ``length_bytes``. The
    byte total is divided by ``BYTES_PER_TOKEN`` to estimate tokens, which
    ``get_token_bucket`` then maps to a bucket id.

    Used as the ``Dataset.map`` function. The returned keys become the new
    ``python_tokens`` and ``token_size_bucket`` columns.
    """
    total_bytes = 0
    for file_info in repo['files']:
        if file_info.get('language') != 'Python':
            continue
        # Skip vendored or generated code; missing flags count as False.
        if file_info.get('is_vendor', False) or file_info.get('is_generated', False):
            continue
        total_bytes += file_info['length_bytes']

    python_tokens = total_bytes / BYTES_PER_TOKEN
    return {
        'python_tokens': python_tokens,
        'token_size_bucket': get_token_bucket(python_tokens),
    }

# Main processing
#
# The pipeline runs from main() behind an ``if __name__ == "__main__"`` guard.
# map/filter run with num_proc worker processes, and with the "spawn" start
# method (macOS, Windows) the main module may be re-imported in each worker.
# Unguarded top-level code would then re-run the whole pipeline there.

# Human-readable range label for every bucket id get_token_bucket() returns.
bucket_names = {
    -1: "No Python",
    0: "< 1k",
    1: "1k-2k",
    2: "2k-4k",
    3: "4k-8k",
    4: "8k-16k",
    5: "16k-32k",
    6: "32k-64k",
    7: "64k-128k",
    8: "128k-256k",
    9: "256k-512k",
    10: "512k-1M",
    11: "> 1M",
}


def default_num_proc():
    """Return the worker count for map/filter: half the CPUs, at least 1.

    ``os.cpu_count()`` may return None, and on a single-core machine
    ``cpu_count() // 2`` is 0, which ``datasets`` rejects as ``num_proc``.
    Both cases are clamped to 1.
    """
    return max(1, (os.cpu_count() or 1) // 2)


def compute_num_shards(num_repos, repos_per_shard=1_000_000, max_shards=64):
    """Return how many shards to write for a bucket of ``num_repos`` rows.

    Aims for roughly ``repos_per_shard`` rows per shard (floor division),
    clamped to ``[1, max_shards]``.
    """
    return max(1, min(max_shards, num_repos // repos_per_shard))


def is_in_bucket(buckets, target):
    """Batched filter predicate: flag each bucket id equal to ``target``."""
    return [bucket == target for bucket in buckets]


def get_cache_dir():
    """Return $HF_DATASETS_CACHE, or ~/.cache/huggingface/datasets if unset."""
    return Path(os.environ.get(
        'HF_DATASETS_CACHE',
        Path.home() / '.cache' / 'huggingface' / 'datasets'
    ))


def load_and_annotate(num_proc):
    """Load the dataset and add the python_tokens / token_size_bucket columns."""
    print("Loading dataset from cache...")
    ds = load_dataset("bigcode/the-stack-v2-train-smol-ids", split="train")

    print(f"Dataset size: {len(ds):,} repositories")

    print("Adding bucket information...")
    return ds.map(
        add_bucket_info,
        num_proc=num_proc,
        desc="Computing Python tokens and log2 buckets"
    )


def save_buckets(ds, cache_dir, num_proc):
    """Save one dataset per non-empty bucket under ``cache_dir``.

    Returns a list of ``(bucket_id, bucket_name, num_repos, output_path)``
    tuples, one for each bucket that was written.
    """
    print("\n" + "="*60)
    print("Filtering and saving separate datasets per bucket...")
    print("="*60)

    bucket_stats = []
    for bucket_id in sorted(bucket_names):
        bucket_name = bucket_names[bucket_id]
        print(f"\n[Bucket {bucket_id:>2}] {bucket_name:<15}")

        # Each pass reads only the small bucket column, in batches, so the
        # nested 'files' column is not decoded once per bucket. The target
        # id goes through fn_kwargs instead of a lambda capturing the loop
        # variable.
        bucket_ds = ds.filter(
            is_in_bucket,
            input_columns="token_size_bucket",
            batched=True,
            fn_kwargs={"target": bucket_id},
            num_proc=num_proc,
            desc=f"Filtering bucket {bucket_id}"
        )

        num_repos = len(bucket_ds)
        if num_repos == 0:
            print("  → No repos in this bucket, skipping")
            continue

        num_shards = compute_num_shards(num_repos)
        output_path = cache_dir / f"stack-v2-smol-ids-bucket-{bucket_id:02d}"
        bucket_ds.save_to_disk(str(output_path), num_shards=num_shards)

        print(f"  → Saved {num_repos:>12,} repos ({num_shards:>2} shards)")
        print(f"  → Path: {output_path}")

        bucket_stats.append((bucket_id, bucket_name, num_repos, output_path))
    return bucket_stats


def print_summary(bucket_stats):
    """Print a table of the saved buckets and the total repo count."""
    print("\n" + "="*60)
    print("Summary:")
    print("="*60)
    print(f"{'Bucket':<8} {'Range':<15} {'Repos':>15} {'Path'}")
    print("-"*60)
    for bucket_id, name, count, path in bucket_stats:
        print(f"{bucket_id:>6}   {name:<15} {count:>15,}   {path.name}")

    total_repos = sum(count for _, _, count, _ in bucket_stats)
    print("-"*60)
    print(f"{'Total':<24} {total_repos:>15,}")


def print_usage(bucket_stats):
    """Print a copy-pasteable example that loads one of the saved buckets."""
    print("\n" + "="*60)
    print("Usage Examples:")
    print("="*60)
    if not bucket_stats:
        print("# No buckets were saved.")
        return

    # Prefer the 32k-64k bucket (id 6), but fall back to the first saved
    # bucket so the printed path always exists.
    example_id, example_name, _, example_path = next(
        (stats for stats in bucket_stats if stats[0] == 6), bucket_stats[0]
    )
    var = f"bucket_{example_id}" if example_id >= 0 else "bucket_no_python"

    print("# Load a specific bucket (instant!)")
    print("from datasets import load_from_disk")
    # !r emits a valid, correctly escaped Python string literal, including
    # for Windows paths that contain backslashes.
    print(f"{var} = load_from_disk({str(example_path)!r})")
    print()
    print(f"# Iterate through all repos in {example_name} range")
    print(f"for repo in {var}:")
    print("    # process repo...")
    print("    pass")


def main():
    """Annotate the dataset, save per-bucket datasets and print a report."""
    num_proc = default_num_proc()
    ds = load_and_annotate(num_proc)
    bucket_stats = save_buckets(ds, get_cache_dir(), num_proc)
    print_summary(bucket_stats)
    print_usage(bucket_stats)
    print("\nDone!")


if __name__ == "__main__":
    main()
