"""
Pre-tokenize repositories and save as binary files for training.

Processes buckets 4-6 to generate ~20B tokens total:
- Bucket 4 (8k-16k): 589,346 repos
- Bucket 5 (16k-32k): 297,736 repos
- Bucket 6 (32k-64k): 125,529 repos

Outputs per bucket:
- bucket_0X_train.bin: tokens as uint16
- bucket_0X_train_offsets.bin: sequence start positions as uint64
"""

import gzip
import zlib
from collections import deque
from concurrent.futures import ThreadPoolExecutor, wait
from pathlib import Path

import numpy as np
from datasets import load_from_disk
from tokenizers import Tokenizer
from tqdm import tqdm

# Configuration
CACHE_DIR = Path.home() / '.cache' / 'huggingface' / 'datasets'
# Tokens are stored as uint16 on disk, so the tokenizer selected here must
# have at most 65,536 entries.
VOCAB = "32k"
TOKENIZER_PATH = Path(f"tokenizer_{VOCAB}.json")
OUTPUT_DIR = Path(f"pretokenized_data_{VOCAB}")
# Bucket ids to process, e.g. [4, 5, 6] for the 8k-16k, 16k-32k and 32k-64k
# buckets. Currently configured to process bucket 8 only.
BUCKETS_TO_PROCESS = [8]

# Local blob directories (downloaded with s5cmd); formatted with the bucket id
BLOB_DIR_TEMPLATE = "blob_files_bucket_{bucket}"

# Debug mode: set to a small number (e.g., 100) to process a subset of repos for testing.
# Set to None to process all repos in each bucket.
MAX_REPOS_PER_BUCKET = None  # None = process all repos

# Parallelization: number of worker threads processing repos concurrently.
# Blob reads overlap across threads; tokenization only runs in parallel if the
# tokenizer's encode() releases the GIL (TODO confirm for `tokenizers`).
NUM_PARALLEL_REPOS = 16


def read_local_blob(blob_id: str, blob_dir: Path, encoding: str = 'utf-8') -> "str | None":
    """Read, gunzip and decode one blob file from local disk.

    This is best-effort: a blob that cannot be read, decompressed or decoded
    yields ``None`` so the caller can skip that file instead of aborting.

    Args:
        blob_id: File name of the blob inside ``blob_dir``.
        blob_dir: Directory holding the gzipped blob files (``str`` or ``Path``).
        encoding: Text encoding for the decompressed bytes. A falsy value
            (e.g. ``None`` from missing dataset metadata) falls back to UTF-8.

    Returns:
        The decoded text, or ``None`` if the blob is missing/unreadable, is not
        valid gzip data, or cannot be decoded with ``encoding``.
    """
    try:
        # EAFP: open directly rather than exists()+open(), which races and
        # costs an extra stat per file.
        with open(Path(blob_dir) / blob_id, 'rb') as f:
            raw = f.read()
        # All SWH blob files are gzip-compressed.
        body = gzip.decompress(raw)
        return body.decode(encoding or 'utf-8')
    except (OSError, EOFError, zlib.error, ValueError, LookupError, TypeError):
        # OSError: missing/unreadable file, or bad gzip header (gzip.BadGzipFile)
        # EOFError: truncated gzip stream; zlib.error: corrupt deflate data
        # ValueError (incl. UnicodeDecodeError): bytes invalid for `encoding`
        # LookupError: unknown encoding name
        # TypeError: malformed metadata (e.g. a non-str blob_id)
        return None


def depth_augmented_sort(files):
    """Return a new list of ``files`` ordered shallowest-first, ties broken by path.

    Depth is the number of '/' separators in each file's ``'path'`` entry; the
    input is not modified.
    """
    def _depth_then_path(entry):
        path = entry['path']
        return path.count('/'), path

    ordered = list(files)
    ordered.sort(key=_depth_then_path)
    return ordered


def construct_repo_sequence(repo, blob_dir):
    """
    Build the StarCoder-formatted training text for one repository.

    The text is ``<reponame>{repo_name}``, then one ``<filename>{path}`` section
    (path, a newline, then the file content) per readable Python file that is
    neither vendored nor generated, ordered by depth_augmented_sort, and finally
    ``<|endoftext|>``; all pieces are joined with newlines. Blobs are read one
    after another on the calling thread (parallelism is at the repo level).

    Returns:
        (sequence_string, True) on success, or (None, False) when the repo has
        no eligible Python files or none of them could be read.
    """
    def _eligible(entry):
        return (
            entry.get('language') == 'Python'
            and not entry.get('is_vendor', False)
            and not entry.get('is_generated', False)
        )

    candidates = [entry for entry in repo['files'] if _eligible(entry)]
    if not candidates:
        return None, False

    sections = []
    for entry in depth_augmented_sort(candidates):
        text = read_local_blob(
            entry['blob_id'],
            blob_dir,
            entry.get('src_encoding', 'utf-8'),
        )
        if text is None:
            # Unreadable blob: skip this file but keep the rest of the repo.
            continue
        sections.append(f"<filename>{entry['path']}\n{text}")

    if not sections:
        return None, False

    header = f"<reponame>{repo.get('repo_name', 'unknown')}"
    return "\n".join([header, *sections, "<|endoftext|>"]), True


def process_single_repo(repo, blob_dir, tokenizer):
    """
    Build one repo's training sequence and encode it with ``tokenizer``.

    Returns:
        (token_ids, True) where token_ids is ``tokenizer.encode(sequence).ids``,
        or (None, False) when construct_repo_sequence produced no sequence.
    """
    text, built = construct_repo_sequence(repo, blob_dir)
    if built:
        return tokenizer.encode(text).ids, True
    return None, False


def _iter_repo_results(ds, blob_dir, tokenizer):
    """Yield ``process_single_repo`` results for every repo in ``ds``, in dataset order.

    Repos run on ``NUM_PARALLEL_REPOS`` worker threads with at most
    ``2 * NUM_PARALLEL_REPOS`` submitted at a time, so pending work stays
    bounded. Results come back in submission order (not completion order), so
    the output is deterministic given the same dataset and blobs.
    """
    max_in_flight = NUM_PARALLEL_REPOS * 2
    with ThreadPoolExecutor(max_workers=NUM_PARALLEL_REPOS) as executor:
        pending = deque()
        for repo in ds:
            pending.append(executor.submit(process_single_repo, repo, blob_dir, tokenizer))
            if len(pending) >= max_in_flight:
                # Wait on the oldest repo; newer ones keep running meanwhile.
                yield pending.popleft().result()
        # Dataset exhausted: drain whatever is still in flight.
        while pending:
            yield pending.popleft().result()


def _print_bucket_results(results):
    """Print the statistics dict built by ``process_bucket``."""
    print(f"\n  Results:")
    print(f"    Successful repos: {results['successful_repos']:,}/{results['total_repos']:,} "
          f"({results['success_rate']*100:.1f}%)")
    print(f"    Failed repos: {results['failed_repos']:,}")
    print(f"    Total tokens: {results['total_tokens']/1e9:.2f}B")
    print(f"    Sequences: {results['num_sequences']:,}")
    print(f"    Tokens file: {results['tokens_file_size_mb']:.1f} MB")
    print(f"    Offsets file: {results['offsets_file_size_mb']:.1f} MB")


def process_bucket(bucket_id: int, tokenizer: Tokenizer):
    """
    Process a bucket: construct sequences, tokenize, write to binary files.

    Each repo's tokens are streamed to ``bucket_XX_train.bin`` as uint16 as soon
    as they are available, instead of being accumulated in one Python list, so
    peak memory does not grow with the bucket's token count. Both output files
    are written under a ``.tmp`` name and renamed into place only once the
    bucket is complete, so an interrupted run never leaves truncated files under
    the final names.

    Args:
        bucket_id: Bucket number; selects the blob directory, the cached
            dataset and the output file names.
        tokenizer: Tokenizer used to encode each repo sequence. Its vocabulary
            must fit in uint16 ids (at most 65,536 entries).

    Returns:
        Dict with statistics, or None if the blob directory or dataset is
        missing, or the vocabulary is too large for uint16 storage.
    """
    print(f"\n{'='*60}")
    print(f"Processing Bucket {bucket_id}")
    print(f"{'='*60}")

    # Get blob directory for this bucket
    blob_dir = Path(BLOB_DIR_TEMPLATE.format(bucket=bucket_id))
    if not blob_dir.exists():
        print(f"  ERROR: Blob directory not found at {blob_dir}")
        print(f"  Make sure blobs have been downloaded with s5cmd")
        return None

    dataset_path = CACHE_DIR / f"stack-v2-smol-ids-bucket-{bucket_id:02d}"
    if not dataset_path.exists():
        print(f"  ERROR: Dataset not found at {dataset_path}")
        return None

    # Token ids are stored as uint16, so every id must be < 65,536; otherwise
    # numpy silently wraps (older versions) or raises (NumPy 2) on conversion.
    vocab_size = tokenizer.get_vocab_size()
    max_vocab = int(np.iinfo(np.uint16).max) + 1
    if vocab_size > max_vocab:
        print(f"  ERROR: Vocab size {vocab_size:,} exceeds uint16 capacity ({max_vocab:,})")
        return None

    ds = load_from_disk(str(dataset_path))
    total_repos_in_bucket = len(ds)

    # Determine how many repos to process
    if MAX_REPOS_PER_BUCKET is not None:
        repos_to_process = min(MAX_REPOS_PER_BUCKET, total_repos_in_bucket)
        print(f"  DEBUG MODE: Processing {repos_to_process:,} of {total_repos_in_bucket:,} repos")
        ds = ds.select(range(repos_to_process))
    else:
        repos_to_process = total_repos_in_bucket
        print(f"  Total repos: {repos_to_process:,}")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    tokens_path = OUTPUT_DIR / f"bucket_{bucket_id:02d}_train.bin"
    offsets_path = OUTPUT_DIR / f"bucket_{bucket_id:02d}_train_offsets.bin"
    tokens_tmp_path = tokens_path.with_name(tokens_path.name + '.tmp')
    offsets_tmp_path = offsets_path.with_name(offsets_path.name + '.tmp')

    # offsets[i] is where sequence i starts; offsets[-1] is the running total.
    offsets = [0]
    successful_repos = 0
    failed_repos = 0

    print(f"  Processing repos with {NUM_PARALLEL_REPOS} parallel workers...")
    with open(tokens_tmp_path, 'wb') as tokens_file:
        with tqdm(total=repos_to_process, desc=f"  Bucket {bucket_id}", unit="repo") as pbar:
            for tokens, success in _iter_repo_results(ds, blob_dir, tokenizer):
                if success:
                    tokens_file.write(np.asarray(tokens, dtype=np.uint16).tobytes())
                    offsets.append(offsets[-1] + len(tokens))
                    successful_repos += 1
                else:
                    failed_repos += 1
                pbar.update(1)

    print(f"  Writing offsets to disk...")
    np.array(offsets, dtype=np.uint64).tofile(offsets_tmp_path)

    # Publish both files only after the whole bucket succeeded.
    tokens_tmp_path.replace(tokens_path)
    offsets_tmp_path.replace(offsets_path)

    tokens_size_mb = tokens_path.stat().st_size / (1024 * 1024)
    offsets_size_mb = offsets_path.stat().st_size / (1024 * 1024)
    total_tokens = offsets[-1]
    success_rate = successful_repos / repos_to_process if repos_to_process > 0 else 0

    results = {
        'bucket_id': bucket_id,
        'total_repos': repos_to_process,
        'successful_repos': successful_repos,
        'failed_repos': failed_repos,
        'success_rate': success_rate,
        'total_tokens': total_tokens,
        'total_tokens_billions': total_tokens / 1e9,
        'tokens_file_size_mb': tokens_size_mb,
        'offsets_file_size_mb': offsets_size_mb,
        'num_sequences': len(offsets) - 1,
    }

    _print_bucket_results(results)

    return results


def _print_summary(all_results):
    """Print a per-bucket results table followed by an overall Total row.

    Args:
        all_results: Non-empty list of the stats dicts returned by process_bucket.
    """
    print(f"\n{'='*60}")
    print(f"Summary:")
    print(f"{'='*60}")
    print(f"{'Bucket':<8} {'Repos':>12} {'Success':>10} {'Tokens':>15} {'Size (MB)':>12}")
    print(f"{'-'*60}")

    total_tokens_all = 0
    total_size_mb = 0
    total_repos_all = 0
    total_successful = 0

    for r in all_results:
        # Size column = tokens + offsets files, so the Total row is the column sum.
        bucket_size_mb = r['tokens_file_size_mb'] + r['offsets_file_size_mb']
        print(f"{r['bucket_id']:<8} "
              f"{r['total_repos']:>12,} "
              f"{r['success_rate']*100:>9.1f}% "
              f"{r['total_tokens_billions']:>12.2f}B "
              f"{bucket_size_mb:>12.1f}")
        total_tokens_all += r['total_tokens']
        total_size_mb += bucket_size_mb
        total_repos_all += r['total_repos']
        total_successful += r['successful_repos']

    # Avoid ZeroDivisionError when every processed bucket was empty.
    overall_rate = total_successful / total_repos_all if total_repos_all > 0 else 0

    print(f"{'-'*60}")
    print(f"{'Total':<8} "
          f"{total_repos_all:>12,} "
          f"{overall_rate*100:>9.1f}% "
          f"{total_tokens_all/1e9:>12.2f}B "
          f"{total_size_mb:>12.1f}")


def main():
    """Pre-tokenize every bucket in BUCKETS_TO_PROCESS and print a summary.

    Loads the tokenizer from TOKENIZER_PATH (returning early if it is missing),
    runs process_bucket for each bucket, then prints a summary table, the
    generated file names and a snippet showing how to load them. Buckets for
    which process_bucket returns None are left out of the summary.
    """
    print("="*60)
    print("Pre-tokenize Dataset: Buckets to Binary Files")
    print("="*60)
    print(f"Tokenizer: {TOKENIZER_PATH}")
    print(f"Buckets: {BUCKETS_TO_PROCESS}")
    print(f"Blob dir template: {BLOB_DIR_TEMPLATE}")
    print(f"Output dir: {OUTPUT_DIR}")
    print(f"Parallel repos: {NUM_PARALLEL_REPOS}")
    if MAX_REPOS_PER_BUCKET is not None:
        print(f"DEBUG MODE: Processing max {MAX_REPOS_PER_BUCKET:,} repos per bucket")
    else:
        print(f"Processing ALL repos in each bucket")

    if not TOKENIZER_PATH.exists():
        print(f"\nERROR: Tokenizer not found at {TOKENIZER_PATH}")
        print("Train a tokenizer first or update TOKENIZER_PATH")
        return

    print(f"\nLoading tokenizer...")
    tokenizer = Tokenizer.from_file(str(TOKENIZER_PATH))
    print(f"  Vocab size: {tokenizer.get_vocab_size():,}")

    all_results = []
    for bucket_id in BUCKETS_TO_PROCESS:
        result = process_bucket(bucket_id, tokenizer)
        if result:
            all_results.append(result)

    if not all_results:
        # Nothing to summarize (and the summary's success rate would divide by zero).
        print(f"\nNo buckets were processed successfully.")
        return

    _print_summary(all_results)

    print(f"\n{'='*60}")
    print(f"Done! Files saved in: {OUTPUT_DIR.absolute()}/")
    print(f"\nGenerated files:")
    for r in all_results:
        print(f"  bucket_{r['bucket_id']:02d}_train.bin")
        print(f"  bucket_{r['bucket_id']:02d}_train_offsets.bin")

    # Show the loading snippet for a bucket that was actually produced.
    example_id = all_results[0]['bucket_id']
    print(f"\nUsage in training:")
    print(f"  # Load tokens")
    print(f"  tokens = np.memmap('bucket_{example_id:02d}_train.bin', dtype=np.uint16, mode='r')")
    print(f"  offsets = np.memmap('bucket_{example_id:02d}_train_offsets.bin', dtype=np.uint64, mode='r')")
    print(f"  # Get sequence i: tokens[offsets[i]:offsets[i+1]]")


if __name__ == "__main__":
    main()
