"""
Fetch sample files from Software Heritage S3 for tokenizer training.
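Samples one random Python file per repository (up to NUM_FILES files) from repos in the
32k-64k token bucket and downloads their contents from Software Heritage's public S3 bucket.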
"""

import random
import time
from pathlib import Path
from typing import Dict, List, Optional

import boto3
import botocore
from datasets import load_from_disk
from smart_open import open

# --- Configuration -----------------------------------------------------------
BUCKET_TARGET = 6                  # token_size_bucket value to sample from (32k-64k tokens)
NUM_FILES = 10_000                 # number of files to sample in total
OUTPUT_DIR = Path("sample_files")  # directory the fetched files are written to

# Anonymous S3 client: requests are sent unsigned (no credentials), which is
# sufficient because the Software Heritage bucket is publicly readable.
s3_client = boto3.client(
    service_name="s3",
    config=botocore.client.Config(signature_version=botocore.UNSIGNED),
)


def fetch_file_content_s3(blob_id: str, encoding: Optional[str] = 'utf-8') -> Optional[str]:
    """
    Fetch and decode one file's content from the Software Heritage S3 bucket.

    The object at ``s3://softwareheritage/content/<blob_id>`` is streamed through
    ``smart_open`` with gzip decompression, then decoded to text.

    Args:
        blob_id: Software Heritage blob ID, used as the S3 object key.
        encoding: Source encoding of the file. ``None`` or an empty string
            falls back to UTF-8.

    Returns:
        File content as a string, or None if the download or the decoding
        failed. Failures are printed rather than raised so that one bad blob
        does not abort a long batch run.
    """
    s3_url = f"s3://softwareheritage/content/{blob_id}"
    # `or` also covers an explicit None coming from a null `src_encoding`.
    codec = encoding or 'utf-8'

    try:
        # Objects are stored gzip-compressed, but the key has no ".gz" suffix
        # for smart_open to infer that from, so the compression is given explicitly.
        with open(s3_url, "rb", compression=".gz",
                  transport_params={"client": s3_client}) as f:
            raw = f.read()
    except Exception as e:  # best-effort boundary: S3/network/gzip errors vary by backend
        print(f"  Error fetching {blob_id}: {e}")
        return None

    try:
        return raw.decode(codec)
    except (LookupError, UnicodeDecodeError) as e:
        # LookupError: unknown codec name; UnicodeDecodeError: bytes invalid for it.
        print(f"  Error decoding {blob_id} as {codec}: {e}")
        return None


def _eligible_python_files(repo: Dict) -> List[Dict]:
    """Return the repo's Python files that are not flagged as vendored or generated."""
    # `or []` tolerates a missing or null `files` entry instead of raising.
    return [
        f for f in (repo.get('files') or [])
        if f.get('language') == 'Python'
        and not f.get('is_vendor', False)
        and not f.get('is_generated', False)
    ]


def sample_files_from_repos(ds, bucket: int, num_files: int) -> List[Dict]:
    """
    Pick one random Python file from each repo in ``bucket`` until ``num_files``
    files have been collected.

    Repos are visited in the order ``ds`` yields them, so the *repos* used are
    the first matching ones, not a random subset of the bucket. Only the file
    within each repo is chosen at random (``random.choice``). Iteration stops
    as soon as enough files are collected.

    Args:
        ds: Iterable of repo records (dicts) with ``token_size_bucket``,
            ``repo_name`` and a ``files`` list. Each file dict provides ``path``,
            ``blob_id``, ``length_bytes`` and optionally ``language``,
            ``is_vendor``, ``is_generated`` and ``src_encoding``.
            NOTE(review): assumes ``ds`` iterates rows, i.e. a single
            ``datasets.Dataset`` rather than a ``DatasetDict`` — confirm.
        bucket: ``token_size_bucket`` value to keep.
        num_files: Maximum number of files to return; ``<= 0`` returns ``[]``.

    Returns:
        List of dicts with keys ``repo_name``, ``file_path``, ``blob_id``,
        ``size_bytes`` and ``encoding``. The encoding falls back to ``'utf-8'``
        when the source encoding is missing or null.
    """
    # The loop checks the limit only after appending, so without this guard
    # a non-positive num_files would still yield one file.
    if num_files <= 0:
        return []

    print(f"Sampling files from bucket {bucket} (streaming through dataset)...")

    sampled_files: List[Dict] = []
    repos_seen = 0
    repos_in_bucket = 0

    for repo in ds:
        repos_seen += 1

        if repo.get('token_size_bucket') != bucket:
            continue
        repos_in_bucket += 1

        candidates = _eligible_python_files(repo)
        if not candidates:
            continue

        file_info = random.choice(candidates)
        repo_name = repo.get('repo_name', 'unknown')

        sampled_files.append({
            'repo_name': repo_name,
            'file_path': file_info['path'],
            'blob_id': file_info['blob_id'],
            'size_bytes': file_info['length_bytes'],
            # `or` (not only .get's default) also covers an explicit None value.
            'encoding': file_info.get('src_encoding') or 'utf-8',
        })

        print(f"Sampled {len(sampled_files)}/{num_files}: "
              f"{repo_name}/{file_info['path']} "
              f"(scanned {repos_seen:,} repos, found {repos_in_bucket:,} in bucket)")

        if len(sampled_files) >= num_files:
            break

    print(f"\nDone! Sampled {len(sampled_files)} files from {repos_in_bucket:,} repos "
          f"(scanned {repos_seen:,} total repos)")

    return sampled_files


def _output_filename(index: int, file_path: str, width: int, max_path_bytes: int = 200) -> str:
    """Build a flat, unique output filename for the ``index``-th sampled file.

    Directory separators in ``file_path`` are flattened to ``_`` and the
    ``file_<index>_`` prefix keeps names unique. The flattened path is cut to
    its last ``max_path_bytes`` UTF-8 bytes so the name stays below the common
    255-byte filename limit while keeping the file extension.
    """
    flat = file_path.replace('/', '_')
    encoded = flat.encode('utf-8')
    if len(encoded) > max_path_bytes:
        # Keep the tail (extension); drop any multibyte char split by the cut.
        flat = encoded[-max_path_bytes:].decode('utf-8', errors='ignore')
    return f"file_{index:0{width}d}_{flat}"


def main():
    """
    Sample files from the bucketed Stack v2 dataset and download them from S3.

    Loads the bucketed dataset from the local Hugging Face cache and samples up
    to NUM_FILES Python files from bucket BUCKET_TARGET. Each file is fetched
    from Software Heritage S3 and written as UTF-8 into OUTPUT_DIR, then a
    summary is printed. Returns early, after printing an error, if the dataset
    has not been built yet.
    """
    # Load bucketed dataset
    print("Loading bucketed dataset...")
    cache_dir = Path.home() / '.cache' / 'huggingface' / 'datasets'
    dataset_path = cache_dir / "stack-v2-smol-ids-bucketed"

    if not dataset_path.exists():
        print(f"Error: Dataset not found at {dataset_path}")
        print("Please run make-stack-v2-smol-ids-bucketed.py first")
        return

    ds = load_from_disk(str(dataset_path))

    # Sample files
    print(f"\nSampling {NUM_FILES} files from bucket {BUCKET_TARGET} (32k-64k tokens)...")
    sampled_files = sample_files_from_repos(ds, BUCKET_TARGET, NUM_FILES)

    if len(sampled_files) < NUM_FILES:
        print(f"\nWarning: Only found {len(sampled_files)} files")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Fetch file contents from S3
    print("\nFetching file contents from Software Heritage S3...")
    total = len(sampled_files)
    # Pad indices to a uniform width so filenames sort in sample order;
    # a fixed 2-digit pad breaks ordering beyond 99 files.
    width = max(2, len(str(total)))
    successful = 0
    empty = 0

    for i, file_info in enumerate(sampled_files, 1):
        print(f"\n[{i}/{total}] {file_info['repo_name']}/{file_info['file_path']}")
        print(f"  Blob ID: {file_info['blob_id']}")
        print(f"  Encoding: {file_info['encoding']}")

        content = fetch_file_content_s3(
            file_info['blob_id'],
            file_info['encoding']
        )

        # None signals a failed fetch; "" is a successfully fetched empty file.
        if content is None:
            print("  ✗ Failed to fetch")
            continue
        if not content:
            print("  - Skipped: file is empty")
            empty += 1
            continue

        output_path = OUTPUT_DIR / _output_filename(i, file_info['file_path'], width)
        try:
            output_path.write_text(content, encoding='utf-8')
        except OSError as e:
            # One unwritable path should not abort the whole batch.
            print(f"  ✗ Failed to save {output_path}: {e}")
            continue
        print(f"  ✓ Saved: {output_path} ({len(content)} chars)")
        successful += 1

    print(f"\n{'='*60}")
    print("Summary:")
    print(f"{'='*60}")
    print(f"Files sampled: {total}")
    print(f"Files fetched successfully: {successful}")
    print(f"Empty files skipped: {empty}")
    print(f"Output directory: {OUTPUT_DIR.absolute()}")

    if successful > 0:
        print("\nNext steps:")
        print(f"1. Train tokenizer on files in {OUTPUT_DIR}/")
        print("2. Use train_tokenizer.py for vocab size 2^14 = 16,384")


if __name__ == "__main__":
    # Seed the global RNG before anything is sampled so that the random file
    # choices repeat from run to run over the same dataset.
    random.seed(a=42)
    main()
