"""
Estimate total tokens in dataset buckets by sampling repos.
Samples 100 repos per bucket (4-8), tokenizes them, and extrapolates.
"""

import json
import random
from pathlib import Path
from typing import Optional

import boto3
import botocore
from datasets import load_from_disk
# NOTE: shadows the builtin open() for this whole module.
from smart_open import open
from tokenizers import Tokenizer
from tqdm import tqdm

# --- Configuration ---------------------------------------------------------
# Root directory holding the per-bucket datasets read with `load_from_disk`.
CACHE_DIR = Path.home().joinpath('.cache', 'huggingface', 'datasets')
# Serialized `tokenizers` tokenizer used for counting (relative to the CWD).
TOKENIZER_PATH = Path('tokenizer_32k.json')
# Bucket ids 4..8 inclusive: 8k-16k, 16k-32k, 32k-64k, 64k-128k, 128k-256k.
BUCKETS_TO_ESTIMATE = list(range(4, 9))
# Number of repos sampled (without replacement) from each bucket.
REPOS_PER_BUCKET = 100

# Anonymous S3 client: requests are sent unsigned, so no AWS credentials are
# needed (works only for objects that permit anonymous reads).
_UNSIGNED_S3_CONFIG = botocore.client.Config(signature_version=botocore.UNSIGNED)
s3_client = boto3.client('s3', config=_UNSIGNED_S3_CONFIG)

# Bucket id -> {"name": token-range label, "repos": repo count}.
# The repo counts are hard-coded and used as the population size when
# extrapolating; presumably taken from an earlier count of each bucket
# (TODO confirm they still match the datasets under CACHE_DIR).
BUCKET_INFO = {
    bucket_id: {"name": label, "repos": repo_count}
    for bucket_id, label, repo_count in (
        (4, "8k-16k", 589_346),
        (5, "16k-32k", 297_736),
        (6, "32k-64k", 125_529),
        (7, "64k-128k", 47_661),
        (8, "128k-256k", 17_012),
    )
}


def fetch_file_content_s3(blob_id: str, encoding: Optional[str] = 'utf-8') -> Optional[str]:
    """Fetch and decode one file's content from the Software Heritage S3 bucket.

    The object at ``s3://softwareheritage/content/<blob_id>`` is opened with
    gzip decompression and the resulting bytes are decoded with ``encoding``.

    Args:
        blob_id: Content blob id, appended verbatim to the S3 key prefix.
        encoding: Codec name passed to ``bytes.decode``. ``None`` or an empty
            string falls back to ``'utf-8'`` (e.g. a file record whose
            ``src_encoding`` is null).

    Returns:
        The decoded text, or ``None`` if opening, reading, decompressing or
        decoding fails for any reason.
    """
    s3_url = f"s3://softwareheritage/content/{blob_id}"
    # Without this fallback a null encoding makes bytes.decode() raise
    # TypeError, which the handler below would swallow, dropping the file.
    codec = encoding or 'utf-8'

    try:
        with open(s3_url, "rb", compression=".gz",
                  transport_params={"client": s3_client}) as f:
            raw = f.read()
        return raw.decode(codec)
    except Exception:
        # Deliberate best-effort: callers treat None as "skip this file"
        # (missing object, network error, unknown codec, undecodable bytes).
        return None


def depth_augmented_sort(files):
    """Return a new list of file records ordered shallow-first.

    Records are ranked by directory depth (the number of '/' in ``path``);
    records at the same depth are ordered by their path string. The input
    iterable is not modified.
    """
    def _depth_then_path(record):
        path = record['path']
        return path.count('/'), path

    ordered = list(files)
    ordered.sort(key=_depth_then_path)
    return ordered


def construct_repo_sequence(repo):
    """
    Build one StarCoder-style text sequence for a repository.

    Keeps the repo's Python files that are not flagged as vendored or
    generated, orders them with ``depth_augmented_sort``, downloads each one
    from S3 and joins everything with newlines::

        <reponame>{repo_name}
        <filename>{path}
        {content}
        ...
        <|endoftext|>

    Files whose download or decoding fails are skipped. NOTE: a repo where
    only some files could be fetched still yields a (shorter) sequence, so
    fetch failures bias token counts downward rather than being counted as
    failed repos.

    Args:
        repo: Mapping with a ``'files'`` list of per-file dicts (keys read:
            ``language``, ``is_vendor``, ``is_generated``, ``path``,
            ``blob_id``, ``src_encoding``) and optionally ``'repo_name'``.

    Returns:
        The sequence string, or ``None`` if the repo has no eligible Python
        files or none of them could be fetched.
    """
    repo_name = repo.get('repo_name', 'unknown')

    # Filter Python files
    python_files = [
        f for f in repo['files']
        if f.get('language') == 'Python'
        and not f.get('is_vendor', False)
        and not f.get('is_generated', False)
    ]
    if not python_files:
        return None

    parts = [f"<reponame>{repo_name}"]
    for file_info in depth_augmented_sort(python_files):
        # `.get(key, default)` only covers a missing key; a present-but-null
        # src_encoding must also fall back to UTF-8, otherwise decode(None)
        # fails inside the fetch and the file is silently dropped.
        encoding = file_info.get('src_encoding') or 'utf-8'
        content = fetch_file_content_s3(file_info['blob_id'], encoding)
        if content is None:
            continue
        parts.append(f"<filename>{file_info['path']}\n{content}")

    if len(parts) == 1:  # only the repo-name header: every fetch failed
        return None

    parts.append("<|endoftext|>")
    return "\n".join(parts)


def estimate_bucket(bucket_id: int, tokenizer: Tokenizer):
    """
    Estimate the total number of tokens in one bucket by sampling repos.

    Loads ``CACHE_DIR / "stack-v2-smol-ids-bucket-XX"`` with ``load_from_disk``,
    draws up to ``REPOS_PER_BUCKET`` rows without replacement, builds each
    repo's sequence with ``construct_repo_sequence`` and tokenizes it. The
    mean token count of the successful repos is multiplied by the bucket's
    hard-coded repo count (``BUCKET_INFO``) scaled by the sample success rate.

    Args:
        bucket_id: Key into ``BUCKET_INFO``.
        tokenizer: Tokenizer whose ``encode(...).ids`` length is counted.

    Returns:
        A dict of sample statistics and estimates (also printed), or ``None``
        if the dataset directory is missing or no sampled repo produced a
        sequence.

    Raises:
        KeyError: If ``bucket_id`` is not in ``BUCKET_INFO``.
    """
    bucket_name = BUCKET_INFO[bucket_id]["name"]
    total_repos = BUCKET_INFO[bucket_id]["repos"]

    print(f"\n{'='*60}")
    print(f"Bucket {bucket_id} ({bucket_name}): {total_repos:,} repos")
    print(f"{'='*60}")

    # Load dataset
    dataset_path = CACHE_DIR / f"stack-v2-smol-ids-bucket-{bucket_id:02d}"
    if not dataset_path.exists():
        print(f"  ERROR: Dataset not found at {dataset_path}")
        return None

    ds = load_from_disk(str(dataset_path))
    dataset_size = len(ds)

    # The extrapolation below scales by the hard-coded BUCKET_INFO count, not
    # by the size of the dataset actually sampled; surface any disagreement
    # instead of silently extrapolating with a possibly stale number.
    if dataset_size != total_repos:
        print(f"  WARNING: dataset on disk has {dataset_size:,} repos but "
              f"BUCKET_INFO lists {total_repos:,}; extrapolating with "
              f"{total_repos:,}")

    # Sample without replacement (the whole dataset if it is smaller).
    sample_indices = random.sample(range(dataset_size), min(REPOS_PER_BUCKET, dataset_size))
    sample_repos = ds.select(sample_indices)

    print(f"  Sampling {len(sample_repos)} repos...")

    # Tokenize each repo; repos that yield no sequence count as failures.
    token_counts = []
    failed = 0

    for repo in tqdm(sample_repos, desc="  Processing"):
        sequence = construct_repo_sequence(repo)
        if sequence is None:
            failed += 1
            continue
        token_counts.append(len(tokenizer.encode(sequence).ids))

    if not token_counts:
        print("  ERROR: No repos successfully processed")
        return None

    # Calculate statistics over successfully processed repos only.
    avg_tokens = sum(token_counts) / len(token_counts)
    min_tokens = min(token_counts)
    max_tokens = max(token_counts)

    # Scale the per-repo mean by the share of the bucket expected to yield a
    # sequence at all (the sample success rate).
    successful_rate = len(token_counts) / len(sample_repos)
    estimated_valid_repos = int(total_repos * successful_rate)
    estimated_total_tokens = avg_tokens * estimated_valid_repos

    results = {
        'bucket_id': bucket_id,
        'bucket_name': bucket_name,
        'total_repos': total_repos,
        'dataset_repos': dataset_size,
        'sampled_repos': len(sample_repos),
        'successful_repos': len(token_counts),
        'failed_repos': failed,
        'success_rate': successful_rate,
        'avg_tokens_per_repo': avg_tokens,
        'min_tokens': min_tokens,
        'max_tokens': max_tokens,
        'estimated_valid_repos': estimated_valid_repos,
        'estimated_total_tokens': estimated_total_tokens,
        'estimated_total_tokens_billions': estimated_total_tokens / 1e9,
    }

    print("\n  Results:")
    print(f"    Successful: {len(token_counts)}/{len(sample_repos)} "
          f"({successful_rate*100:.1f}%)")
    print(f"    Avg tokens/repo: {avg_tokens:,.0f}")
    print(f"    Range: {min_tokens:,} - {max_tokens:,}")
    print(f"    Estimated valid repos: {estimated_valid_repos:,}")
    print(f"    Estimated total tokens: {estimated_total_tokens/1e9:.2f}B")

    return results


def main():
    """
    Estimate tokens for every bucket in BUCKETS_TO_ESTIMATE and report them.

    Prints a per-bucket summary table and writes all results to
    ``bucket_token_estimates.json`` in the current directory. Returns early,
    writing nothing, if TOKENIZER_PATH does not exist.
    """
    print("="*60)
    print("Token Estimation for Dataset Buckets")
    print("="*60)
    print(f"Tokenizer: {TOKENIZER_PATH}")
    print(f"Buckets: {BUCKETS_TO_ESTIMATE}")
    print(f"Samples per bucket: {REPOS_PER_BUCKET}")

    # Load tokenizer
    if not TOKENIZER_PATH.exists():
        print(f"\nERROR: Tokenizer not found at {TOKENIZER_PATH}")
        print("Train a tokenizer first or update TOKENIZER_PATH")
        return

    print("\nLoading tokenizer...")
    tokenizer = Tokenizer.from_file(str(TOKENIZER_PATH))
    print(f"  Vocab size: {tokenizer.get_vocab_size():,}")

    # Process each bucket, keeping only those that produced an estimate.
    all_results = []
    for bucket_id in BUCKETS_TO_ESTIMATE:
        result = estimate_bucket(bucket_id, tokenizer)
        if result:
            all_results.append(result)

    # Summary table. Header, rows and total line use the same column widths
    # (8, 12, 12, 12, 15) so the columns line up.
    header = (f"{'Bucket':<8} {'Range':<12} {'Repos':>12} "
              f"{'Tokens/Repo':>12} {'Total Tokens':>15}")
    rule = '-' * len(header)

    print(f"\n{'='*60}")
    print("Summary:")
    print(f"{'='*60}")
    print(header)
    print(rule)

    total_tokens_all = 0
    for r in all_results:
        print(f"{r['bucket_id']:<8} "
              f"{r['bucket_name']:<12} "
              f"{r['total_repos']:>12,} "
              f"{r['avg_tokens_per_repo']:>12,.0f} "
              f"{r['estimated_total_tokens_billions']:>14.2f}B")
        total_tokens_all += r['estimated_total_tokens']

    print(rule)
    # 47 = width of the first four columns plus their separating spaces.
    print(f"{'Total':<47} {total_tokens_all/1e9:>14.2f}B")

    # Save results. Path.open is the builtin file open; the module-level name
    # `open` is smart_open's (imported for S3 access) and is not needed here.
    output_path = Path("bucket_token_estimates.json")
    with output_path.open('w', encoding='utf-8') as f:
        json.dump({
            'tokenizer': str(TOKENIZER_PATH),
            'samples_per_bucket': REPOS_PER_BUCKET,
            'buckets': all_results,
            'total_estimated_tokens': total_tokens_all,
            'total_estimated_tokens_billions': total_tokens_all / 1e9,
        }, f, indent=2)

    print(f"\nResults saved to: {output_path}")


if __name__ == "__main__":
    # Seed the global RNG before main() so random.sample() in estimate_bucket
    # draws the same repos on every run.
    random.seed(a=42)
    main()
