"""
Bootstrap tokenizer training: iteratively train on easier-to-encode subsets.

Algorithm:
1. Train tokenizer on all files
2. Evaluate compression on all original files
3. Keep easiest 80% (exclude hardest 20%)
4. Retrain on filtered 80%
5. Repeat for N iterations

The filtered set changes each iteration based on the current tokenizer.
"""

from tokenizers import Tokenizer, models, trainers, pre_tokenizers, normalizers
from pathlib import Path
from typing import List, Tuple, Dict
from tqdm import tqdm
import json
import os

# Configuration
# Target vocabulary size handed to the BPE trainer. 2**14 == 16,384.
# NOTE(review): this line was previously annotated as 32,768, which would be
# 2**15 — confirm which vocabulary size is actually intended.
VOCAB_SIZE = 2**14  # 16,384
INPUT_DIR = Path("sample_files")  # training corpus directory
TEST_DIR = Path("test_files")  # held-out evaluation directory
# The directory name embeds the vocab size in units of 1024 (16 for 2**14).
OUTPUT_DIR = Path(f"bootstrap_tokenizers_{VOCAB_SIZE // 2**10}k")
KEEP_PERCENTAGE = 0.80  # Keep easiest 80%, exclude hardest 20%
NUM_ITERATIONS = 5  # number of train/evaluate rounds

# StarCoder-style special tokens
SPECIAL_TOKENS = [
    "<|endoftext|>",
    "<reponame>",
    "<filename>",
    "<gh_stars>",
    "<fim_prefix>",
    "<fim_middle>",
    "<fim_suffix>",
    "<fim_pad>",
]


def get_all_training_files() -> List[Path]:
    """Return every training file in INPUT_DIR, in a deterministic order.

    Only regular files whose name matches ``file_*`` are returned, sorted by
    path. Sorting matters because ``glob`` yields entries in a
    filesystem-dependent order, and a stable sort on compression ratio keeps
    tied files in their input order — so without a fixed input order, the
    set of files that survives a percentage cut-off could differ between
    runs on the same corpus.

    Raises:
        FileNotFoundError: if INPUT_DIR does not exist, or contains no
            matching regular files.
    """
    if not INPUT_DIR.exists():
        raise FileNotFoundError(f"Input directory not found: {INPUT_DIR}")

    # Directories named file_* would make training/evaluation fail later,
    # so keep regular files only.
    files = sorted(p for p in INPUT_DIR.glob("file_*") if p.is_file())
    if not files:
        raise FileNotFoundError(f"No files found in {INPUT_DIR}")

    return files


def get_test_files() -> List[Path]:
    """Return the test files in TEST_DIR, or an empty list if it is missing.

    Only regular files whose name matches ``file_*`` are returned, sorted by
    path so that repeated runs see the files in the same order (``glob``
    order is filesystem-dependent).
    """
    if not TEST_DIR.exists():
        return []
    return sorted(p for p in TEST_DIR.glob("file_*") if p.is_file())


def train_tokenizer(training_files: List[str], iteration: int) -> Tokenizer:
    """Train a byte-level BPE tokenizer and save it under OUTPUT_DIR.

    Args:
        training_files: Paths (as strings) of the files to train on.
        iteration: Iteration index; used only to name the saved file
            ``tokenizer_iter{iteration}.json``.

    Returns:
        The trained tokenizer.
    """
    print(f"\n  Training tokenizer on {len(training_files):,} files...")

    # BPE model with an empty normalizer sequence and byte-level
    # pre-tokenization (no prefix space added).
    bpe_tokenizer = Tokenizer(models.BPE())
    bpe_tokenizer.normalizer = normalizers.Sequence([])
    bpe_tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

    # The trainer is seeded with the byte-level alphabet and the special tokens.
    bpe_tokenizer.train(
        files=training_files,
        trainer=trainers.BpeTrainer(
            vocab_size=VOCAB_SIZE,
            special_tokens=SPECIAL_TOKENS,
            show_progress=True,
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
        ),
    )

    # Persist this iteration's tokenizer.
    OUTPUT_DIR.mkdir(exist_ok=True)
    saved_to = OUTPUT_DIR / f"tokenizer_iter{iteration}.json"
    bpe_tokenizer.save(str(saved_to))
    print(f"  Saved: {saved_to}")

    return bpe_tokenizer


def evaluate_compression(tokenizer: Tokenizer, files: List[Path]) -> List[Tuple[Path, float, int, int]]:
    """
    Evaluate compression on files.

    Each file is read as UTF-8 and encoded with ``tokenizer``. Files that
    are empty, cannot be read/decoded, or encode to zero tokens are left out
    of the result, so the returned list may be shorter than ``files``.
    The number of unreadable files is printed once after the loop.

    Returns list of (file_path, chars_per_token, num_chars, num_tokens)
    """
    results = []
    unreadable = 0

    for file_path in tqdm(files, desc="  Evaluating", unit="file"):
        # Guard only the read: I/O and decoding problems are expected and
        # skipped, while an error from the tokenizer itself should surface
        # instead of silently shrinking the result set.
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                text = f.read()
        except (OSError, UnicodeError):
            unreadable += 1
            continue

        if not text:
            continue

        num_tokens = len(tokenizer.encode(text).ids)
        if num_tokens == 0:
            # chars/token is undefined without tokens; skip rather than
            # divide by zero.
            continue

        num_chars = len(text)
        results.append((file_path, num_chars / num_tokens, num_chars, num_tokens))

    # Report after the loop so the message does not interleave with the
    # progress bar.
    if unreadable:
        print(f"  Warning: skipped {unreadable:,} unreadable file(s)")

    return results


def compute_metrics(results: List[Tuple[Path, float, int, int]]) -> Dict:
    """Summarise per-file compression results.

    Args:
        results: Tuples of (file_path, chars_per_token, num_chars, num_tokens).

    Returns:
        An empty dict when ``results`` is empty; otherwise a dict with the
        unweighted mean, the corpus-level (total chars / total tokens) ratio,
        the min and max per-file ratio, and the number of files.
    """
    if len(results) == 0:
        return {}

    # Transpose the rows into columns; the path column is not needed here.
    _, ratios, char_counts, token_counts = zip(*results)

    return dict(
        avg_compression=sum(ratios) / len(ratios),
        weighted_compression=sum(char_counts) / sum(token_counts),
        min_compression=min(ratios),
        max_compression=max(ratios),
        num_files=len(results),
    )


def _print_iteration_metrics(
    iteration: int,
    num_train_files: int,
    train_metrics: Dict,
    num_test_files: int,
    test_metrics: Dict,
) -> None:
    """Print the train (and, if present, test) metrics for one iteration."""
    print(f"\n  Metrics (Iteration {iteration}):")
    print(f"    Training set (all {num_train_files:,} files):")
    print(f"      Weighted compression: {train_metrics['weighted_compression']:.2f} chars/token")
    print(f"      Average compression:  {train_metrics['avg_compression']:.2f} chars/token")
    print(f"      Range: {train_metrics['min_compression']:.2f} - {train_metrics['max_compression']:.2f}")

    if test_metrics:
        print(f"    Test set ({num_test_files:,} files):")
        print(f"      Weighted compression: {test_metrics['weighted_compression']:.2f} chars/token")
        print(f"      Average compression:  {test_metrics['avg_compression']:.2f} chars/token")


def _select_easiest(
    results: List[Tuple[Path, float, int, int]],
    keep_fraction: float,
) -> Tuple[List[Tuple[Path, float, int, int]], List[Tuple[Path, float, int, int]]]:
    """Split results into (kept, excluded), ranked easiest (highest chars/token) first."""
    ranked = sorted(results, key=lambda r: r[1], reverse=True)
    # Base the cut on the files that were actually evaluated, and never keep
    # zero files, otherwise the next round would train on nothing.
    num_to_keep = max(1, int(len(ranked) * keep_fraction)) if ranked else 0
    return ranked[:num_to_keep], ranked[num_to_keep:]


def _print_summary(all_metrics: List[Dict]) -> None:
    """Print the per-iteration comparison table and the list of saved tokenizers."""
    print(f"\n{'='*60}")
    print("Summary Comparison:")
    print(f"{'='*60}")
    print(f"{'Iter':<6} {'Train Files':<15} {'Train Comp':<15} {'Test Comp':<15}")
    print(f"{'-'*60}")
    for m in all_metrics:
        train_comp = m['train_metrics'].get('weighted_compression', 0)
        test_comp = m['test_metrics'].get('weighted_compression', 0)
        print(f"{m['iteration']:<6} "
              f"{m['num_training_files']:>10,}{'':<5} "
              f"{train_comp:>10.2f}{'':<5} "
              f"{test_comp:>10.2f}{'':<5}")

    print(f"\n{'='*60}")
    print(f"Done! Tokenizers saved in: {OUTPUT_DIR}/")
    # List one line per tokenizer that was actually trained, instead of a
    # fixed set of names.
    for m in all_metrics:
        i = m['iteration']
        if i == 0:
            description = "trained on all files"
        else:
            description = (f"trained on easiest {KEEP_PERCENTAGE*100:.0f}% "
                           f"as ranked by iter{i - 1}")
        print(f"  - tokenizer_iter{i}.json ({description})")


def main():
    """Run the bootstrap loop: train, evaluate on all files, keep the easiest, repeat.

    Each iteration trains a tokenizer on the current subset, evaluates it on
    every original training file (and on the test files, if any), and then
    picks the easiest KEEP_PERCENTAGE of the evaluated files as the training
    set for the next iteration. Per-iteration metrics are written to
    ``bootstrap_metrics.json`` in OUTPUT_DIR and a summary is printed.

    Raises:
        FileNotFoundError: propagated from get_all_training_files().
        RuntimeError: if none of the training files could be evaluated.
    """
    print("="*60)
    print("Bootstrap Tokenizer Training")
    print("="*60)
    print(f"Vocab size: {VOCAB_SIZE:,}")
    print(f"Keep percentage: {KEEP_PERCENTAGE*100:.0f}%")
    print(f"Iterations: {NUM_ITERATIONS}")
    print(f"Training dir: {INPUT_DIR}")
    print(f"Test dir: {TEST_DIR}")

    # Get all files
    all_training_files = get_all_training_files()
    test_files = get_test_files()

    print(f"\nTotal training files: {len(all_training_files):,}")
    print(f"Test files: {len(test_files):,}")

    # Track metrics across iterations
    all_metrics = []

    # Current training set (starts with all files)
    current_training_files = all_training_files

    for iteration in range(NUM_ITERATIONS):
        print(f"\n{'='*60}")
        print(f"Iteration {iteration}")
        print(f"{'='*60}")
        print(f"Training on: {len(current_training_files):,} files "
              f"({len(current_training_files)/len(all_training_files)*100:.1f}%)")

        tokenizer = train_tokenizer(
            [str(f) for f in current_training_files],
            iteration
        )

        # Evaluate on ALL original training files (not just current subset)
        print(f"\n  Evaluating on all {len(all_training_files):,} original training files...")
        train_results = evaluate_compression(tokenizer, all_training_files)
        train_metrics = compute_metrics(train_results)
        if not train_metrics:
            # Without any evaluated file there is nothing to report or rank.
            raise RuntimeError(
                f"No training files in {INPUT_DIR} could be evaluated "
                "(all were empty or unreadable)"
            )

        # Evaluate on test set
        test_metrics = {}
        if test_files:
            print(f"  Evaluating on {len(test_files):,} test files...")
            test_metrics = compute_metrics(evaluate_compression(tokenizer, test_files))

        _print_iteration_metrics(
            iteration,
            len(all_training_files),
            train_metrics,
            len(test_files),
            test_metrics,
        )

        all_metrics.append({
            'iteration': iteration,
            'num_training_files': len(current_training_files),
            'train_metrics': train_metrics,
            'test_metrics': test_metrics
        })

        # Select easiest files for next iteration (if not last iteration)
        if iteration < NUM_ITERATIONS - 1:
            kept, excluded = _select_easiest(train_results, KEEP_PERCENTAGE)
            current_training_files = [r[0] for r in kept]

            print(f"\n  For next iteration: keeping easiest {len(kept):,} files "
                  f"({KEEP_PERCENTAGE*100:.0f}%), excluding hardest {len(excluded):,}")

            # Show the hardest excluded files, worst first
            if excluded:
                print("  Hardest 5 files (excluded):")
                for i, (path, comp, _, _) in enumerate(excluded[-5:][::-1], 1):
                    print(f"    {i}. {comp:.2f} chars/token - {path.name}")

    # Save metrics summary. The directory is normally created by
    # train_tokenizer, but make sure it exists even if no iteration ran.
    OUTPUT_DIR.mkdir(exist_ok=True)
    metrics_path = OUTPUT_DIR / "bootstrap_metrics.json"
    with open(metrics_path, 'w', encoding='utf-8') as f:
        json.dump(all_metrics, f, indent=2)
    print(f"\n{'='*60}")
    print(f"Metrics saved to: {metrics_path}")

    _print_summary(all_metrics)


# Run the bootstrap training loop only when executed as a script, not on import.
if __name__ == "__main__":
    main()
