from datasets import load_dataset
import numpy as np
from collections import defaultdict

# Tunable parameters for the estimation run.
SAMPLE_SIZE = 50_000  # number of Python-containing repos to collect before stopping
PERCENTILES = [80, 90, 95, 99]  # repo-size percentiles reported in the analysis
BYTES_PER_TOKEN = 3.5  # rough bytes-per-token ratio assumed for Python source

# Load dataset
# Stream the split so rows are fetched lazily: the sampling loop stops as
# soon as enough repositories have been collected, so materializing the
# whole split up front (streaming=False) is unnecessary and contradicts
# the status message printed here.
print("Loading dataset (streaming)...")
ds = load_dataset(
    "bigcode/the-stack-v2-train-smol-ids",
    split="train",
    streaming=True,
)

# Walk the dataset and record an estimated Python token count per repository.
print(f"Sampling {SAMPLE_SIZE:,} repositories...")
python_repo_tokens = []
total_repos_seen = 0
python_repos_seen = 0

for total_repos_seen, repo in enumerate(ds, start=1):
    # Add up the byte sizes of this repo's Python files, leaving out any
    # file flagged as vendored or generated.
    python_bytes = 0
    for entry in repo['files']:
        if entry.get('language') != 'Python':
            continue
        if entry.get('is_vendor', False) or entry.get('is_generated', False):
            continue
        python_bytes += entry['length_bytes']

    # Repos without any countable Python bytes are left out of the sample.
    if python_bytes > 0:
        python_repo_tokens.append(python_bytes / BYTES_PER_TOKEN)
        python_repos_seen += 1

    # Stop once enough Python-containing repositories have been collected.
    if python_repos_seen >= SAMPLE_SIZE:
        break

    # Periodic progress report, keyed on the number of repos visited.
    if total_repos_seen % 10000 == 0:
        print(f"  Processed {total_repos_seen:,} repos, "
              f"found {python_repos_seen:,} with Python...")

# Convert to an array so the analysis can use vectorized reductions and masks.
python_repo_tokens = np.array(python_repo_tokens)

# Calculate statistics
# Bail out with a clear message when nothing was collected: min()/max() on
# an empty array raise ValueError, and the percentage below would divide by
# zero if no repositories were visited at all.
if python_repo_tokens.size == 0:
    raise SystemExit("No repositories with Python files were sampled; "
                     "nothing to report.")

# Share of visited repositories that contained countable Python code.
python_share = python_repos_seen / total_repos_seen * 100

print(f"\n{'='*60}")
print("Sample Statistics:")
print(f"{'='*60}")
print(f"Repositories sampled: {total_repos_seen:,}")
print(f"Repositories with Python: {python_repos_seen:,} "
      f"({python_share:.1f}%)")
print("\nToken Distribution:")
print(f"  Min: {python_repo_tokens.min():,.0f}")
print(f"  Median: {np.median(python_repo_tokens):,.0f}")
print(f"  Mean: {python_repo_tokens.mean():,.0f}")
print(f"  Max: {python_repo_tokens.max():,.0f}")
print(f"  Total in sample: {python_repo_tokens.sum():,.0f}")

# Extrapolate the Python-repo count from the sample to the whole dataset,
# treating the visited repositories as representative of the rest.
# NOTE(review): 104_200_000 is used here as the total number of repositories
# in the dataset -- confirm this figure against the dataset card.
sampled_fraction = total_repos_seen / 104_200_000
estimated_total_python_repos = python_repos_seen / sampled_fraction
print(f"\nEstimated total Python repos in dataset: {estimated_total_python_repos:,.0f}")

# Percentile table: for each percentile, the token cutoff plus the repos and
# tokens at or above that cutoff, extrapolated to the full dataset.
print(f"\n{'='*60}")
print("Percentile Analysis:")
print(f"{'='*60}")
percentile_header = [
    ('Percentile', 12),
    ('Cutoff (tokens)', 18),
    ('Repos Above', 15),
    ('Tokens Above', 18),
    ('% of Total', 12),
]
print(" ".join(f"{title:<{width}}" for title, width in percentile_header))
print(f"{'-'*60}")

# Loop invariants: sample size and total tokens in the sample.
sample_count = len(python_repo_tokens)
sample_token_total = python_repo_tokens.sum()

for percentile in PERCENTILES:
    cutoff = np.percentile(python_repo_tokens, percentile)
    at_or_above = python_repo_tokens >= cutoff  # inclusive of the cutoff itself
    repos_above = at_or_above.sum()
    tokens_above = python_repo_tokens[at_or_above].sum()

    # Scale the sample figures by the estimated full-dataset population.
    pct_repos_above = repos_above / sample_count
    estimated_repos_above = estimated_total_python_repos * pct_repos_above
    estimated_tokens_above = (tokens_above / sample_count) * estimated_total_python_repos
    pct_of_total = (tokens_above / sample_token_total) * 100

    row_cells = (
        f"{percentile}th{'':<8}",
        f"{cutoff:>13,.0f}{'':<5}",
        f"{int(estimated_repos_above):>10,}{'':<5}",
        f"{estimated_tokens_above/1e9:>10.1f}B{'':<6}",
        f"{pct_of_total:>8.1f}%",
    )
    print(" ".join(row_cells))

# Breakdown by fixed token bands: how many sampled repos fall in each band
# and how many tokens that band is estimated to hold across the dataset.
print(f"\n{'='*60}")
print("Token Range Distribution:")
print(f"{'='*60}")

# (inclusive lower bound, exclusive upper bound, display label)
ranges = [
    (0, 1_000, "< 1k"),
    (1_000, 5_000, "1k-5k"),
    (5_000, 10_000, "5k-10k"),
    (10_000, 50_000, "10k-50k"),
    (50_000, 100_000, "50k-100k"),
    (100_000, 500_000, "100k-500k"),
    (500_000, float('inf'), "> 500k"),
]

range_header = [('Range', 15), ('Repos', 15), ('Total Tokens', 18), ('% of Total', 12)]
print(" ".join(f"{title:<{width}}" for title, width in range_header))
print(f"{'-'*60}")

# Loop invariants: sample size and total tokens in the sample.
sample_count = len(python_repo_tokens)
sample_token_total = python_repo_tokens.sum()

for low, high, label in ranges:
    mask = (python_repo_tokens >= low) & (python_repo_tokens < high)
    repos_in_range = mask.sum()  # count within the sample, not extrapolated
    tokens_in_range = python_repo_tokens[mask].sum()
    pct = (tokens_in_range / sample_token_total) * 100

    # Token figure is extrapolated to the estimated full-dataset population.
    estimated_tokens = (tokens_in_range / sample_count) * estimated_total_python_repos

    row_cells = (
        f"{label:<15}",
        f"{repos_in_range:>10,}{'':<5}",
        f"{estimated_tokens/1e9:>10.1f}B{'':<6}",
        f"{pct:>8.1f}%",
    )
    print(" ".join(row_cells))

# Check, for each percentile cutoff, whether the repos at or above it are
# estimated to supply at least 20B tokens.
print(f"\n{'='*60}")
print("Recommendation for 20B Token Budget:")
print(f"{'='*60}")

sample_count = len(python_repo_tokens)

for percentile in PERCENTILES:
    cutoff = np.percentile(python_repo_tokens, percentile)
    tokens_above = python_repo_tokens[python_repo_tokens >= cutoff].sum()
    estimated_tokens_above = (tokens_above / sample_count) * estimated_total_python_repos
    billions = estimated_tokens_above / 1e9

    # NOTE: the label prints '>' although the selection above is inclusive (>=).
    tier = f"{percentile}th percentile (>{cutoff:,.0f} tokens): "

    if billions >= 20:
        print(f"✓ {tier}{billions:.1f}B tokens available")
    else:
        print(f"✗ {tier}Only {billions:.1f}B tokens (insufficient)")

def get_tokens_in_range(lower, upper):
    """Print and return the estimated dataset-wide tokens for a size band.

    Selects sampled repositories whose token count lies in the half-open
    interval ``[lower, upper)``, sums their tokens, and scales that sum by
    the ratio of estimated total Python repositories to sampled ones. Reads
    the module-level ``python_repo_tokens`` array and
    ``estimated_total_python_repos``.

    Args:
        lower: Inclusive lower bound, in tokens.
        upper: Exclusive upper bound, in tokens.

    Returns:
        The estimated token count for the band (raw tokens, not billions),
        so callers can reuse the figure instead of parsing the printed line.
    """
    in_band = (python_repo_tokens >= lower) & (python_repo_tokens < upper)
    tokens_in_range = python_repo_tokens[in_band].sum()
    estimated_tokens = (tokens_in_range / len(python_repo_tokens)) * estimated_total_python_repos
    # Bounds are labelled in decimal thousands (floor division by 1000).
    print(f"Estimated tokens in range {lower//1_000}k-{upper//1_000}k: {estimated_tokens/1e9:.1f}B")
    return estimated_tokens

print(f"\n{'='*60}")

# Report every band whose bounds are two distinct powers of two between
# 2**14 and 2**18, ordered by lower bound and then by upper bound.
bounds = [2**exponent for exponent in range(14, 19)]
for position, low in enumerate(bounds[:-1]):
    for high in bounds[position + 1:]:
        get_tokens_in_range(low, high)
