"""
Construct full repo sequences with StarCoder-style formatting.
Concatenates each repo's Python files with special tokens, ordered by directory depth (shallowest first) and then alphabetically by path.
"""

from pathlib import Path
from typing import Optional, Tuple

import boto3
import botocore
from datasets import load_from_disk
from smart_open import open
from tqdm import tqdm

# Configuration
# Location of the bucketed dataset produced by the bucketing script.
BUCKET_DATASET_PATH = Path.home().joinpath(
    '.cache', 'huggingface', 'datasets', 'stack-v2-smol-ids-bucket-06',
)
# Number of repos, taken from the start of the dataset, to process.
NUM_REPOS = 10
# Directory (relative to the working directory) receiving one .txt per repo.
OUTPUT_DIR = Path("repo_sequences")

# The bucket is public, so requests go out unsigned (no credentials needed).
s3_client = boto3.client(
    's3',
    config=botocore.client.Config(
        signature_version=botocore.UNSIGNED,
    ),
)


def fetch_file_content_s3(blob_id: str, encoding: str = 'utf-8') -> Optional[str]:
    """Fetch and decode one file's content from the Software Heritage S3 bucket.

    The object at ``s3://softwareheritage/content/<blob_id>`` is opened with
    ``compression=".gz"`` through the module-level unsigned ``s3_client``.

    Args:
        blob_id: Key of the blob under ``content/`` in the bucket.
        encoding: Codec used to decode the raw bytes. Empty or None falls
            back to UTF-8.

    Returns:
        The decoded text, or None if downloading or decoding failed. The
        error is printed in either case.
    """
    s3_url = f"s3://softwareheritage/content/{blob_id}"
    # Datasets may carry a null/empty encoding; bytes.decode(None) would fail.
    encoding = encoding or 'utf-8'

    # Best-effort download: any failure (missing object, network, bad gzip)
    # is reported and mapped to None so one bad blob does not abort a repo.
    try:
        with open(s3_url, "rb", compression=".gz",
                  transport_params={"client": s3_client}) as f:
            raw = f.read()
    except Exception as e:
        print(f"    Error fetching {blob_id}: {e}")
        return None

    # Decoding failures (unknown codec name, undecodable bytes) are reported
    # separately so they aren't mistaken for download errors.
    try:
        return raw.decode(encoding)
    except (LookupError, UnicodeDecodeError) as e:
        print(f"    Error decoding {blob_id} as {encoding}: {e}")
        return None


def depth_augmented_sort(files):
    """Return a new list of ``files`` ordered shallowest-first.

    The primary key is the number of ``/`` characters in each entry's
    ``path`` (fewer means shallower). Ties are broken by ordinary string
    comparison of the path. The sort is stable and the input is left
    unmodified.
    """
    def _depth_then_path(entry):
        path = entry['path']
        return path.count('/'), path

    ordered = list(files)
    ordered.sort(key=_depth_then_path)
    return ordered


def construct_repo_sequence(repo) -> Optional[Tuple[str, int]]:
    """
    Construct a full repo sequence with StarCoder-style special tokens.

    Only files with ``language == 'Python'`` that are not flagged
    ``is_vendor`` or ``is_generated`` are used. They are ordered with
    ``depth_augmented_sort``, and each one's content is fetched from S3.
    Files whose fetch returns None are skipped.

    The pieces are joined with newlines, so the result looks like::

        <reponame>owner/repo
        <filename>path/file1.py
        [file1 content]
        <filename>path/file2.py
        [file2 content]
        <|endoftext|>

    Args:
        repo: Dataset row with a ``files`` list. ``repo_name`` is optional
            (missing/empty -> ``'unknown/unknown'``). Each file dict needs
            ``path`` and ``blob_id``. ``language``, ``is_vendor``,
            ``is_generated`` and ``src_encoding`` are optional.

    Returns:
        ``(sequence, files_fetched)``, or None if the repo has no eligible
        Python files or none of them could be fetched.
    """
    repo_name = repo.get('repo_name') or 'unknown/unknown'

    # Keep only first-party Python sources: drop vendored and generated files.
    python_files = [
        f for f in repo['files']
        if f.get('language') == 'Python'
        and not f.get('is_vendor', False)
        and not f.get('is_generated', False)
    ]

    if not python_files:
        return None

    sequence_parts = [f"<reponame>{repo_name}"]
    files_fetched = 0
    for file_info in depth_augmented_sort(python_files):
        file_path = file_info['path']
        blob_id = file_info['blob_id']
        # A present-but-null src_encoding would otherwise make decoding fail
        # and silently drop the file.
        encoding = file_info.get('src_encoding') or 'utf-8'

        content = fetch_file_content_s3(blob_id, encoding)
        if content is None:
            # Fetch/decode failed; skip rather than emit an empty file body.
            continue

        sequence_parts.append(f"<filename>{file_path}\n{content}")
        files_fetched += 1

    if files_fetched == 0:
        return None

    # End with the EOS token; every piece is separated by a newline.
    sequence_parts.append("<|endoftext|>")
    return "\n".join(sequence_parts), files_fetched


def main():
    """Build and save StarCoder-style sequences for the first NUM_REPOS repos.

    The dataset is loaded from BUCKET_DATASET_PATH. For every repo that
    yields a sequence, one ``repo_<index>_<name>.txt`` file is written into
    OUTPUT_DIR. Afterwards a summary is printed, followed by a preview of the
    first sequence written during this run. If the dataset path does not
    exist, an error is printed and the function returns early.
    """
    print("=" * 60)
    print("Constructing Repository Sequences")
    print("=" * 60)

    # Load dataset
    print(f"Loading dataset from: {BUCKET_DATASET_PATH}")

    if not BUCKET_DATASET_PATH.exists():
        print("ERROR: Dataset not found!")
        print(f"Expected: {BUCKET_DATASET_PATH}")
        print("Make sure the bucketing script has completed.")
        return

    ds = load_from_disk(str(BUCKET_DATASET_PATH))
    print(f"Loaded dataset with {len(ds):,} repos")

    # Create output directory
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Clamp to the dataset size: ds.select() raises on out-of-range indices.
    num_repos = min(NUM_REPOS, len(ds))
    print(f"\nProcessing first {num_repos} repos...")

    successful = 0
    # Preview a file written by *this* run; globbing OUTPUT_DIR could pick an
    # arbitrary or stale file left over from an earlier run.
    first_output_path = None
    for i, repo in enumerate(tqdm(ds.select(range(num_repos)), desc="Repos")):
        repo_name = repo.get('repo_name') or f'unknown_{i}'
        safe_name = repo_name.replace('/', '_').replace(' ', '_')

        # tqdm.write prints above the progress bar instead of corrupting it.
        tqdm.write(f"\n[{i + 1}/{num_repos}] {repo_name}")

        result = construct_repo_sequence(repo)
        if result is None:
            # Either no eligible Python files, or every fetch failed.
            tqdm.write("  Skipped: No valid Python files could be fetched")
            continue

        sequence, num_files = result

        # Save to file
        output_path = OUTPUT_DIR / f"repo_{i:07d}_{safe_name}.txt"
        output_path.write_text(sequence, encoding='utf-8')
        if first_output_path is None:
            first_output_path = output_path

        tqdm.write(f"  ✓ Saved: {output_path.name}")
        tqdm.write(f"    Files: {num_files}")
        tqdm.write(f"    Length: {len(sequence):,} chars")

        successful += 1

    # Summary
    print(f"\n{'=' * 60}")
    print("Summary:")
    print("=" * 60)
    print(f"Repos processed: {num_repos}")
    print(f"Sequences created: {successful}")
    print(f"Output directory: {OUTPUT_DIR.absolute()}")

    if first_output_path is not None:
        print("\nExample sequence format:")
        # Read back with the same encoding used for writing, rather than the
        # locale default, and without going through the shadowed `open`.
        content = first_output_path.read_text(encoding='utf-8')

        # Show first 1000 chars
        print(content[:1000])
        if len(content) > 1000:
            print(f"\n... ({len(content) - 1000:,} more chars)")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
