"""
Fetch sample files from Software Heritage for tokenizer training.
Samples 10 files from repos in the 32k-64k token bucket.
"""

import random
import time
from pathlib import Path
from typing import Dict, List, Optional

import requests
from datasets import load_from_disk

# Configuration
# Value compared against each repo's 'token_size_bucket' field when sampling.
BUCKET_TARGET = 6  # 32k-64k tokens
# Number of files to sample and then try to download.
NUM_FILES = 10
# Directory the fetched files are written to (relative to the working directory).
OUTPUT_DIR = Path("sample_files")
# Base URL that the Software Heritage content request URLs are built from.
SWH_API_BASE = "https://archive.softwareheritage.org/api/1"

# Rate limiting
REQUEST_DELAY = 0.5  # seconds between requests


def fetch_file_content(swhid: str) -> Optional[str]:
    """
    Fetch file content from Software Heritage using SWHID.

    SWHID format: swh:1:cnt:<hash>, optionally followed by ';'-separated
    qualifiers (e.g. ';origin=...'), which are ignored.

    Args:
        swhid: Content SWHID of the file to download.

    Returns:
        The file content decoded as UTF-8, or None when the request fails,
        the server answers with a non-200 status, or the body is not valid
        UTF-8. A message describing the failure is printed in each case.

    Raises:
        ValueError: If swhid does not start with 'swh:1:cnt:'.
    """
    prefix = "swh:1:cnt:"
    if not swhid.startswith(prefix):
        raise ValueError(f"Invalid SWHID format: {swhid}")

    # Qualifiers may themselves contain ':' (URLs), so cut them off before
    # taking the hash rather than splitting the whole string on ':'.
    content_hash = swhid[len(prefix):].split(";", 1)[0]

    # Fetch from Software Heritage API
    url = f"{SWH_API_BASE}/content/sha1_git:{content_hash}/raw/"

    print(f"  Fetching: {url}")

    try:
        response = requests.get(url, timeout=30)
    except requests.RequestException as exc:
        # Timeouts / connection errors are reported like any other failed
        # fetch so a single flaky request does not abort the whole run.
        print(f"  Error: Request failed ({exc})")
        return None

    if response.status_code == 404:
        print("  Warning: Content not found (404)")
        return None

    if response.status_code != 200:
        print(f"  Error: Status code {response.status_code}")
        return None

    # Content is returned as bytes
    try:
        return response.content.decode('utf-8')
    except UnicodeDecodeError:
        print("  Warning: Could not decode as UTF-8, skipping")
        return None


def sample_files_from_repos(ds, bucket: int, num_files: int) -> List[Dict]:
    """
    Sample Python files from repos in the specified bucket.

    Repos are drawn at random (with replacement) from those whose
    'token_size_bucket' equals ``bucket``; from each drawn repo one random
    Python file that is neither vendored nor generated is taken. A file that
    was already sampled (same 'blob_id') is not added a second time.

    Args:
        ds: Dataset-like object providing ``filter(fn)``; the filtered result
            must support ``len()`` and integer indexing, with each row a
            mapping holding 'token_size_bucket', 'files' and optionally
            'repo_name'.
        bucket: Bucket value to select repos by.
        num_files: Number of files to collect.

    Returns:
        List of dicts with keys 'repo_name', 'file_path', 'swhid' and
        'size_bytes'. May hold fewer than ``num_files`` entries (or none)
        when the bucket is empty or the attempt budget runs out.
    """
    # Filter to target bucket
    print(f"Filtering to bucket {bucket}...")
    repos_in_bucket = ds.filter(lambda x: x['token_size_bucket'] == bucket)

    num_repos = len(repos_in_bucket)
    print(f"Found {num_repos:,} repos in bucket {bucket}")

    sampled_files: List[Dict] = []

    # random.randint(0, -1) raises ValueError, so an empty bucket must be
    # handled before the sampling loop.
    if num_repos == 0:
        return sampled_files

    seen_blob_ids = set()
    attempts = 0
    max_attempts = num_files * 10  # Try up to 10x to get enough files

    while len(sampled_files) < num_files and attempts < max_attempts:
        attempts += 1

        # Pick a random repo
        repo = repos_in_bucket[random.randint(0, num_repos - 1)]

        # Filter Python files (non-vendor, non-generated)
        python_files = [
            f for f in repo['files']
            if f.get('language') == 'Python'
            and not f.get('is_vendor', False)
            and not f.get('is_generated', False)
        ]

        if not python_files:
            continue

        # Pick a random Python file
        file_info = random.choice(python_files)

        # The same repo (or identical content in another repo) can be drawn
        # again; duplicates add nothing to the sample, so try another draw.
        blob_id = file_info['blob_id']
        if blob_id in seen_blob_ids:
            continue
        seen_blob_ids.add(blob_id)

        repo_name = repo.get('repo_name', 'unknown')
        sampled_files.append({
            'repo_name': repo_name,
            'file_path': file_info['path'],
            'swhid': blob_id,
            'size_bytes': file_info['length_bytes'],
        })

        print(f"Sampled {len(sampled_files)}/{num_files}: "
              f"{repo_name}/{file_info['path']}")

    return sampled_files


def main():
    """
    Sample files from the bucketed dataset and download them.

    Loads the dataset saved at
    ~/.cache/huggingface/datasets/stack-v2-smol-ids-bucketed, samples
    NUM_FILES files from bucket BUCKET_TARGET, fetches each one from Software
    Heritage and writes it into OUTPUT_DIR, then prints a summary. Returns
    early (after printing an error) when the dataset directory is missing.
    """
    # Load bucketed dataset
    print("Loading bucketed dataset...")
    cache_dir = Path.home() / '.cache' / 'huggingface' / 'datasets'
    dataset_path = cache_dir / "stack-v2-smol-ids-bucketed"

    if not dataset_path.exists():
        print(f"Error: Dataset not found at {dataset_path}")
        print("Please run make-stack-v2-smol-ids-bucketed.py first")
        return

    ds = load_from_disk(str(dataset_path))

    # Sample files
    print(f"\nSampling {NUM_FILES} files from bucket {BUCKET_TARGET} (32k-64k tokens)...")
    sampled_files = sample_files_from_repos(ds, BUCKET_TARGET, NUM_FILES)

    if len(sampled_files) < NUM_FILES:
        print(f"\nWarning: Only found {len(sampled_files)} files")

    # Create output directory
    OUTPUT_DIR.mkdir(exist_ok=True)

    # Fetch file contents
    print("\nFetching file contents from Software Heritage...")
    successful = 0
    total = len(sampled_files)

    for i, file_info in enumerate(sampled_files, 1):
        print(f"\n[{i}/{total}] {file_info['repo_name']}/{file_info['file_path']}")

        try:
            content = fetch_file_content(file_info['swhid'])
        except ValueError as exc:
            # A single malformed identifier should not abort the whole batch;
            # report it and move on to the next file.
            print(f"  Warning: {exc}, skipping")
            content = None

        # Falsy content covers both a failed fetch (None) and an empty file,
        # neither of which is worth saving.
        if content:
            # Flatten the repo-relative path into one file name; the index
            # prefix keeps names unique across repos.
            safe_filename = f"file_{i:02d}_{file_info['file_path'].replace('/', '_')}"
            output_path = OUTPUT_DIR / safe_filename

            output_path.write_text(content, encoding='utf-8')
            print(f"  Saved: {output_path} ({len(content)} chars)")
            successful += 1

        # Rate limiting (no need to wait after the last request)
        if i < total:
            time.sleep(REQUEST_DELAY)

    print(f"\n{'='*60}")
    print("Summary:")
    print(f"{'='*60}")
    print(f"Files sampled: {total}")
    print(f"Files fetched successfully: {successful}")
    print(f"Output directory: {OUTPUT_DIR.absolute()}")

    if successful > 0:
        print("\nNext steps:")
        print(f"1. Train tokenizer on files in {OUTPUT_DIR}/")
        print("2. Example tokenizer training script:")
        print("   from tokenizers import Tokenizer, trainers, models, pre_tokenizers")
        print("   # See train-tokenizer.py for full example")


if __name__ == "__main__":
    # Seed the module-level RNG before main() runs, since the sampling uses
    # random.randint / random.choice.
    random.seed(42)  # For reproducibility
    main()
