"""
Estimate tokens per topic and length bucket based on compressed file sizes.
Length buckets: 2e13 = 2^13-2^14 tokens, 2e15 = 2^15-2^16 tokens, etc.
"""

from huggingface_hub import HfFileSystem
from collections import defaultdict
import re

DATASET_NAME = "allenai/dolma3_longmino_pool"

print("="*80)
print(f"Estimating Tokens: {DATASET_NAME}")
print("="*80)

print("\nFetching file list with sizes...")
fs = HfFileSystem()

# List all files in the dataset with their metadata
repo_path = f"datasets/{DATASET_NAME}"
files_info = fs.ls(repo_path, detail=True, recursive=True)

# Filter for data files
data_files = [f for f in files_info if f['name'].endswith('.jsonl.zst') and '/data/' in f['name']]

print(f"Found {len(data_files):,} data files\n")

# Parse and group by topic and length bucket
# Format: data/{source}-{topic}-{length_bucket}/{shard}.jsonl.zst
topic_bucket_files = defaultdict(lambda: defaultdict(list))

for file_info in data_files:
    path = file_info['name']
    size = file_info['size']

    # Extract topic and bucket from path
    # Example: datasets/allenai/dolma3_longmino_pool/data/olmocr_science_pdfs-science_tech-2e15/p020_shard_00002809.jsonl.zst
    match = re.search(r'/data/([^/]+)-(\d+e\d+)/', path)
    if match:
        topic = match.group(1)
        bucket = match.group(2)
        topic_bucket_files[topic][bucket].append({
            'path': path,
            'size': size
        })

print("="*80)
print("Files and Size per Topic and Bucket:")
print("="*80)

# Calculate totals
all_results = []

for topic in sorted(topic_bucket_files.keys()):
    print(f"\n{topic}:")

    buckets_data = topic_bucket_files[topic]
    topic_total_size = 0
    topic_total_files = 0

    for bucket in sorted(buckets_data.keys(), key=lambda x: int(x.split('e')[1])):
        files = buckets_data[bucket]
        total_size = sum(f['size'] for f in files)

        topic_total_size += total_size
        topic_total_files += len(files)

        # Parse bucket: 2e15 means 2^15 to 2^16 tokens
        exp = int(bucket.split('e')[1])
        token_range_min = 2 ** exp
        token_range_max = 2 ** (exp + 1)

        # Estimate tokens from compressed size
        # Typical compression ratio for text + zstd: ~5-10x
        # Assuming ~4 bytes per token in JSON (with overhead)
        # Compressed size * compression_ratio / bytes_per_token
        compression_ratio = 7  # Conservative estimate
        bytes_per_token = 4  # JSON overhead + token data

        estimated_tokens = (total_size * compression_ratio) / bytes_per_token

        all_results.append({
            'topic': topic,
            'bucket': bucket,
            'exponent': exp,
            'token_range': f"2^{exp}-2^{exp+1}",
            'token_range_readable': f"{token_range_min:,}-{token_range_max:,}",
            'num_files': len(files),
            'total_size_mb': total_size / (1024**2),
            'estimated_tokens': estimated_tokens,
            'estimated_tokens_millions': estimated_tokens / 1e6
        })

        print(f"  {bucket} (2^{exp}-2^{exp+1}): "
              f"{len(files):4d} files, "
              f"{total_size/(1024**2):8.1f} MB, "
              f"~{estimated_tokens/1e6:8.1f}M tokens")

    print(f"  {'TOTAL':<5}: "
          f"{topic_total_files:4d} files, "
          f"{topic_total_size/(1024**2):8.1f} MB")

# Summary table
print("\n" + "="*80)
print("Summary Table:")
print("="*80)
header = (f"{'Topic':<40} {'Bucket':<8} {'Range':<20} "
          f"{'Files':>6} {'Size(MB)':>10} {'Est Tokens(M)':>15}")
print(header)
# The header is wider than the 80-column banners; size the rule to the header
# so it underlines every column.
print("-" * len(header))

# Rows ordered by topic name, then numerically by bucket exponent.
for result in sorted(all_results, key=lambda x: (x['topic'], x['exponent'])):
    print(f"{result['topic']:<40} "
          f"{result['bucket']:<8} "
          f"{result['token_range_readable']:<20} "
          f"{result['num_files']:>6,} "
          f"{result['total_size_mb']:>10.1f} "
          f"{result['estimated_tokens_millions']:>15.1f}")

# Totals by bucket across all topics
print("\n" + "="*80)
print("Totals by Bucket (across all topics):")
print("="*80)

# Fold the per-(topic, bucket) rows into one aggregate per bucket. The size is
# in MB and the tokens in millions, taken straight from each row.
bucket_totals = defaultdict(lambda: {'files': 0, 'size': 0, 'tokens': 0})
for row in all_results:
    agg = bucket_totals[row['bucket']]
    agg['files'] += row['num_files']
    agg['size'] += row['total_size_mb']
    agg['tokens'] += row['estimated_tokens_millions']
    agg['exponent'] = row['exponent']

# Emit buckets from the smallest length range to the largest.
ordered_buckets = sorted(bucket_totals.items(), key=lambda kv: kv[1]['exponent'])
for bucket_name, agg in ordered_buckets:
    lo = agg['exponent']
    print(f"{bucket_name} (2^{lo}-2^{lo+1}): "
          f"{agg['files']:6,} files, "
          f"{agg['size']:10.1f} MB, "
          f"~{agg['tokens']:10.1f}M tokens")

# Grand total
# Every figure comes from the grouped (regex-matched) files. The file count,
# size and token estimate therefore describe the same set of shards. Summing
# over data_files would also count shards whose directory name did not parse.
total_files = sum(len(files) for buckets in topic_bucket_files.values() for files in buckets.values())
total_size = sum(
    f['size']
    for buckets in topic_bucket_files.values()
    for files in buckets.values()
    for f in files
) / (1024**2)
total_tokens = sum(r['estimated_tokens_millions'] for r in all_results)

print(f"\nGRAND TOTAL: "
      f"{total_files:6,} files, "
      f"{total_size:10.1f} MB, "
      f"~{total_tokens:10.1f}M tokens")

print("\n" + "="*80)
print("Note: Token estimates assume ~7x compression ratio and ~4 bytes/token in JSON")
print("Actual token counts may vary depending on the tokenizer used")
print("="*80)
