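"""
Fetch test files from Software Heritage S3 for tokenizer evaluation.

Samples up to 100 Python files (at most one per repository) from repos in the
32k-64k token bucket, iterating backwards from the end of the dataset so the
sample is drawn from different rows than the training set.
"""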

import random
import time
from pathlib import Path
from typing import Dict, List, Optional

import boto3
import botocore
from datasets import load_from_disk
from smart_open import open

# ---- Configuration ---------------------------------------------------------
# Token-size bucket to sample from (bucket 6 corresponds to 32k-64k tokens).
BUCKET_TARGET = 6
# Number of test files to collect (at most one per repository).
NUM_FILES = 100
# Destination directory, relative to the current working directory.
OUTPUT_DIR = Path("test_files")

# Anonymous S3 client: the Software Heritage bucket is public, so requests are
# sent unsigned and no AWS credentials are required.
s3_client = boto3.client(
    service_name='s3',
    config=botocore.client.Config(signature_version=botocore.UNSIGNED),
)


def fetch_file_content_s3(blob_id: str, encoding: str = 'utf-8') -> Optional[str]:
    """
    Download one blob from the public Software Heritage S3 bucket and decode it.

    Args:
        blob_id: Identifier used verbatim as the object key under
            ``s3://softwareheritage/content/``. Presumably this is the bare
            hash from the dataset's ``blob_id`` column, not a full
            ``swh:1:cnt:`` SWHID; verify if the data source changes.
        encoding: Codec used to decode the downloaded bytes. A falsy value
            (``None`` or ``""``) falls back to UTF-8.

    Returns:
        The decoded file content, or ``None`` if the download or the decode
        failed (the error is printed). An empty blob yields ``""``, which is
        distinct from failure.
    """
    s3_url = f"s3://softwareheritage/content/{blob_id}"
    # The dataset may carry an explicit null encoding; bytes.decode(None)
    # would raise TypeError and misreport the file as a failed fetch.
    codec = encoding or 'utf-8'

    try:
        # Objects are gzipped on S3 but their keys carry no ".gz" suffix, so
        # the compression is forced rather than inferred from the key.
        with open(s3_url, "rb", compression=".gz",
                  transport_params={"client": s3_client}) as f:
            raw = f.read()
        return raw.decode(codec)
    except Exception as e:
        # Deliberately best-effort: S3/network errors, unknown codec names
        # (LookupError) and undecodable bytes (UnicodeDecodeError) all count
        # as a failed fetch, so the caller can skip this file and continue.
        print(f"  Error fetching {blob_id}: {e}")
        return None


def sample_files_from_repos_reverse(ds, bucket: int, num_files: int,
                                    language: str = 'Python') -> List[Dict]:
    """
    Sample one random file per repository from repos in the given bucket.

    Walks the dataset from its last row to its first. For every repo whose
    ``token_size_bucket`` equals ``bucket``, it picks one file uniformly at
    random (via the global ``random`` module). The pick is made among the
    repo's files whose ``language`` matches and that are not flagged
    ``is_vendor`` or ``is_generated``. Repos with no eligible file are
    skipped. Sampling stops once ``num_files`` files have been collected.

    Args:
        ds: Indexable collection of repo rows (supports ``len()`` and integer
            indexing that returns a dict), e.g. a Hugging Face ``Dataset``.
        bucket: Value of ``token_size_bucket`` to sample from.
        num_files: Maximum number of files to return; ``<= 0`` returns ``[]``.
        language: Value of each file's ``language`` field to keep.

    Returns:
        List of dicts with keys ``repo_name``, ``file_path``, ``blob_id``,
        ``size_bytes`` and ``encoding``, ordered as sampled (highest row
        index first). It may be shorter than ``num_files`` if the dataset
        runs out of eligible repos.
    """
    print(f"Sampling files from bucket {bucket} (iterating BACKWARDS through dataset)...")

    sampled_files = []
    repos_seen = 0
    repos_in_bucket = 0

    dataset_len = len(ds)
    print(f"Dataset has {dataset_len:,} total repos")

    # The loop appends before checking the limit, so without this guard a
    # non-positive num_files would still yield one sampled file.
    if num_files <= 0:
        return sampled_files

    # Walk row indices from dataset_len - 1 down to 0 (positive indices).
    for idx in reversed(range(dataset_len)):
        repos_seen += 1
        repo = ds[idx]

        if repo.get('token_size_bucket') != bucket:
            continue

        repos_in_bucket += 1

        # Eligible files: requested language, not vendored, not generated.
        # A row without a 'files' entry is treated as having no files.
        candidates = [
            f for f in (repo.get('files') or [])
            if f.get('language') == language
            and not f.get('is_vendor', False)
            and not f.get('is_generated', False)
        ]
        if not candidates:
            continue

        file_info = random.choice(candidates)
        repo_name = repo.get('repo_name', 'unknown')

        sampled_files.append({
            'repo_name': repo_name,
            'file_path': file_info['path'],
            'blob_id': file_info['blob_id'],
            'size_bytes': file_info['length_bytes'],
            # `or` also covers an explicit None value, not just a missing key.
            'encoding': file_info.get('src_encoding') or 'utf-8',
        })

        print(f"Sampled {len(sampled_files)}/{num_files}: "
              f"{repo_name}/{file_info['path']} "
              f"(scanned {repos_seen:,} repos from end, found {repos_in_bucket:,} in bucket)")

        if len(sampled_files) >= num_files:
            break

    print(f"\nDone! Sampled {len(sampled_files)} files "
          f"({repos_in_bucket:,} repos in bucket, "
          f"scanned {repos_seen:,} repos from end of dataset)")

    return sampled_files


def main():
    """
    Load the bucketed dataset, sample test files, fetch them from S3, and
    write them under OUTPUT_DIR.

    Expects the dataset produced by make-stack-v2-smol-ids-bucketed.py at
    ~/.cache/huggingface/datasets/stack-v2-smol-ids-bucketed. If it is
    missing, prints an error and returns early.
    """
    print("Loading bucketed dataset...")
    cache_dir = Path.home() / '.cache' / 'huggingface' / 'datasets'
    dataset_path = cache_dir / "stack-v2-smol-ids-bucketed"

    if not dataset_path.exists():
        print(f"Error: Dataset not found at {dataset_path}")
        print("Please run make-stack-v2-smol-ids-bucketed.py first")
        return

    ds = load_from_disk(str(dataset_path))

    # Sample files (iterating backwards)
    print(f"\nSampling {NUM_FILES} test files from bucket {BUCKET_TARGET} (32k-64k tokens)...")
    sampled_files = sample_files_from_repos_reverse(ds, BUCKET_TARGET, NUM_FILES)

    if len(sampled_files) < NUM_FILES:
        print(f"\nWarning: Only found {len(sampled_files)} files")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("\nFetching file contents from Software Heritage S3...")
    successful = 0
    skipped_empty = 0
    total = len(sampled_files)

    for i, file_info in enumerate(sampled_files, 1):
        print(f"\n[{i}/{total}] {file_info['repo_name']}/{file_info['file_path']}")
        print(f"  Blob ID: {file_info['blob_id']}")
        print(f"  Encoding: {file_info['encoding']}")

        content = fetch_file_content_s3(
            file_info['blob_id'],
            file_info['encoding']
        )

        # None signals a failed fetch. "" is a successfully fetched empty file
        # (e.g. an empty __init__.py): it is not a failure, but it is useless
        # for evaluation, so it is skipped with its own message.
        if content is None:
            print("  ✗ Failed to fetch")
            continue
        if not content:
            print("  - Skipped: file is empty")
            skipped_empty += 1
            continue

        safe_filename = f"file_{i:02d}_{file_info['file_path'].replace('/', '_')}"
        output_path = OUTPUT_DIR / safe_filename

        # Write encoded bytes directly: text-mode writing would translate '\n'
        # to os.linesep on Windows and alter the content being evaluated.
        output_path.write_bytes(content.encode('utf-8'))
        print(f"  ✓ Saved: {output_path} ({len(content)} chars)")
        successful += 1

    print(f"\n{'='*60}")
    print("Summary:")
    print(f"{'='*60}")
    print(f"Files sampled: {len(sampled_files)}")
    print(f"Files fetched successfully: {successful}")
    if skipped_empty:
        print(f"Empty files skipped: {skipped_empty}")
    print(f"Output directory: {OUTPUT_DIR.absolute()}")

    if successful > 0:
        print("\nTest files ready for tokenizer evaluation!")
        print("These files are from the END of the dataset (different from training set)")


if __name__ == "__main__":
    # Fixed seed makes the per-repo file choice reproducible; it is
    # deliberately different from the seed used for the training set.
    random.seed(a=123)
    main()
