"""
Train a byte-level BPE tokenizer with vocab size 2^16 = 65,536 (see VOCAB_SIZE) on sampled Python files.
StarCoder-style tokenizer with special tokens for multi-file repos.
"""

from tokenizers import Tokenizer, models, trainers, pre_tokenizers, normalizers
from tokenizers.processors import TemplateProcessing
from pathlib import Path
from typing import List
import re
from tqdm import tqdm

try:
    from langdetect import detect, LangDetectException
    HAS_LANGDETECT = True
except ImportError:
    HAS_LANGDETECT = False
    print("⚠ Warning: langdetect not installed. Install with: pip install langdetect")
    print("  Skipping language filtering for now.\n")

# Configuration
VOCAB_SIZE = 2**16  # 65,536; OUTPUT_PATH below is named from this (64k)
INPUT_DIR = Path("sample_files")  # Training corpus: files matching file_* (fetched by fetch_sample_files_s3.py)
OUTPUT_PATH = Path(f"tokenizer_{VOCAB_SIZE // 2**10}k.json")
FILTER_NON_ENGLISH = False  # True = keep only files langdetect classifies as English (requires langdetect)

# StarCoder-style special tokens. They are passed to the BPE trainer, so each
# gets a reserved ID and is never split during encoding.
SPECIAL_TOKENS = [
    "<|endoftext|>",    # End of sequence/repository
    "<reponame>",       # Repository name
    "<filename>",       # File path
    "<gh_stars>",       # GitHub stars (quality signal)
    "<fim_prefix>",     # Fill-in-middle: prefix
    "<fim_middle>",     # Fill-in-middle: middle
    "<fim_suffix>",     # Fill-in-middle: suffix
    "<fim_pad>",        # Fill-in-middle: padding
]


def is_english_code(file_path: Path) -> bool:
    """
    Heuristically decide whether a code file's natural-language text is English.

    Only string literals of 10+ characters and ``#`` comments are passed to
    langdetect, because code keywords would bias every file toward English.

    Args:
        file_path: Path to a UTF-8 source file.

    Returns:
        False only when langdetect's guess for the extracted text is not 'en'.
        True (keep the file) when langdetect is not installed, when fewer than
        50 characters of text were extracted, when the file cannot be read or
        decoded as UTF-8, or when langdetect cannot detect a language.
    """
    if not HAS_LANGDETECT:
        return True  # Skip filtering if langdetect not available

    # langdetect is non-deterministic unless seeded; without this the same file
    # can be kept on one run and dropped on the next.
    from langdetect import DetectorFactory
    DetectorFactory.seed = 0

    try:
        content = Path(file_path).read_text(encoding='utf-8')
    except (OSError, UnicodeDecodeError):
        return True  # Unreadable file: include it (best effort, as before)

    # Extract strings and comments for language detection
    # (code keywords would bias toward English)
    strings = re.findall(r'["\']([^"\']{10,})["\']', content)
    comments = re.findall(r'#\s*(.+)', content)
    text_content = ' '.join(strings + comments)

    # Need reasonable amount of text to detect
    if len(text_content) < 50:
        return True  # Too little text, assume OK

    try:
        return detect(text_content) == 'en'
    except LangDetectException:
        # Detection failed (e.g. no usable features): include the file
        return True


def get_training_files() -> List[str]:
    """
    Return paths of the training files in INPUT_DIR, optionally English-only.

    Only regular files matching ``file_*`` are used. They are sorted so the
    order is reproducible across runs and filesystems, because ``Path.glob``
    yields entries in arbitrary order. When FILTER_NON_ENGLISH is set and
    langdetect is available, files rejected by ``is_english_code`` are dropped.

    Returns:
        File paths as strings, in sorted order.

    Raises:
        FileNotFoundError: if INPUT_DIR does not exist, contains no matching
            files, or every file is removed by the language filter.
    """
    if not INPUT_DIR.exists():
        raise FileNotFoundError(
            f"Input directory not found: {INPUT_DIR}\n"
            f"Run fetch_sample_files_s3.py first"
        )

    # glob() also matches directories named file_*; those cannot be trained on.
    files = sorted(p for p in INPUT_DIR.glob("file_*") if p.is_file())
    if not files:
        raise FileNotFoundError(f"No files found in {INPUT_DIR}")

    # Filter for English if enabled
    if FILTER_NON_ENGLISH and HAS_LANGDETECT:
        print("Filtering for English-only files...")
        english_files = []
        non_english_files = []

        for f in tqdm(files, desc="Detecting language", unit="file"):
            if is_english_code(f):
                english_files.append(f)
            else:
                non_english_files.append(f)

        print(f"  English files: {len(english_files)}")
        print(f"  Non-English files (excluded): {len(non_english_files)}")

        # Only list excluded names when there are few enough to read.
        if 0 < len(non_english_files) <= 10:
            print(f"  Excluded files: {[f.name for f in non_english_files]}")

        if not english_files:
            # Training on an empty file list would silently produce a useless tokenizer.
            raise FileNotFoundError(
                f"No English files left in {INPUT_DIR} after language filtering"
            )

        files = english_files

    return [str(f) for f in files]


# Maximum number of files sampled from each split for the evaluation report.
_MAX_EVAL_FILES = 100

# Python keywords/builtins the report expects to be encoded as a single token.
_PYTHON_KEYWORDS = [
    "def", "class", "import", "from", "return",
    "if", "else", "elif", "for", "while",
    "try", "except", "finally", "with", "as",
    "self", "None", "True", "False", "pass",
    "lambda", "yield", "async", "await",
]

# Short code snippets whose token counts are shown in the report.
_CODE_PATTERNS = [
    "def __init__(self):",
    "if __name__ == '__main__':",
    "import numpy as np",
    "from typing import",
    "return None",
]


def _print_header(title: str) -> None:
    """Print a section title framed by 60-character '=' rules."""
    print(f"\n{'=' * 60}")
    print(title)
    print("=" * 60)


def _build_tokenizer() -> "Tokenizer":
    """Create an untrained byte-level BPE tokenizer with a matching decoder."""
    from tokenizers import decoders

    tokenizer = Tokenizer(models.BPE())
    # No normalizer: code is tokenized exactly as written (no lowercasing or
    # stripping), so structure and indentation are preserved.
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    # The ByteLevel pre-tokenizer maps raw bytes to printable stand-in
    # characters; without the matching decoder, decode() returns those
    # stand-ins instead of the original source text.
    tokenizer.decoder = decoders.ByteLevel()
    return tokenizer


def _encode_files(tokenizer, file_paths):
    """Encode each file and return ``(stats, errors)``.

    ``stats`` holds ``(file name, character count, token count)`` for every file
    read successfully. ``errors`` holds ``(path, message)`` for files that could
    not be opened or decoded as UTF-8; a warning is also printed for each.
    """
    stats = []
    errors = []
    for file_path in file_paths:
        path = Path(file_path)
        try:
            text = path.read_text(encoding='utf-8')
        except (OSError, UnicodeDecodeError) as e:
            print(f"   Warning: Could not read {path}: {e}")
            errors.append((str(path), str(e)))
            continue
        stats.append((path.name, len(text), len(tokenizer.encode(text).ids)))
    return stats, errors


def _chars_per_token(stats) -> float:
    """Return total characters / total tokens over ``stats`` (0 if there are no tokens)."""
    total_chars = sum(n_chars for _, n_chars, _ in stats)
    total_tokens = sum(n_tokens for _, _, n_tokens in stats)
    return total_chars / total_tokens if total_tokens > 0 else 0


def _gap_pct(train_cpt: float, test_cpt: float) -> float:
    """Return (train - test) as a percentage of test chars/token (0 without a test score)."""
    return (train_cpt - test_cpt) / test_cpt * 100 if test_cpt > 0 else 0


def _average_compression(train_cpt: float, test_cpt: float) -> float:
    """Return the mean of train and test chars/token, or train alone without a test score."""
    return (train_cpt + test_cpt) / 2 if test_cpt > 0 else train_cpt


def _report_special_tokens(tokenizer) -> None:
    """Print the vocabulary size and the ID assigned to each special token."""
    _print_header("Tokenizer Statistics:")
    print(f"Vocabulary size: {tokenizer.get_vocab_size():,}")
    print(f"Special tokens: {len(SPECIAL_TOKENS)}")
    print("\nSpecial tokens:")
    for token in SPECIAL_TOKENS:
        print(f"  {token:<20} → ID {tokenizer.token_to_id(token)}")


def _report_test_set(tokenizer, train_cpt: float) -> float:
    """Print held-out compression and its gap to training; return test chars/token.

    Uses up to ``_MAX_EVAL_FILES`` files matching ``test_files/file_*`` (sorted,
    and English-filtered when filtering is enabled). Returns 0 when there are no
    test files or none could be encoded.
    """
    print("\n   Test Set (held-out):")
    test_dir = Path("test_files")
    test_files = sorted(test_dir.glob("file_*")) if test_dir.exists() else []
    if not test_files:
        print("     ⚠ No test files found")
        print("     Run: python fetch_test_files_s3.py")
        print("     Cannot assess overfitting without test set!")
        return 0

    # Filter for English if enabled
    if FILTER_NON_ENGLISH and HAS_LANGDETECT:
        original_count = len(test_files)
        test_files = [f for f in test_files if is_english_code(f)]
        filtered_count = original_count - len(test_files)
        if filtered_count > 0:
            print(f"     Filtered out {filtered_count} non-English test files")

    test_sample = test_files[:_MAX_EVAL_FILES]
    test_stats, _ = _encode_files(tokenizer, test_sample)
    test_cpt = _chars_per_token(test_stats)
    print(f"     Files: {len(test_sample)}")
    print(f"     Compression: {test_cpt:.2f} chars/token")

    # Compare train vs test
    print("\n   Comparison:")
    if test_cpt <= 0:
        print("     Could not compute comparison")
        return test_cpt

    gap = train_cpt - test_cpt
    gap_pct = _gap_pct(train_cpt, test_cpt)
    print(f"     Train - Test gap: {gap:+.2f} chars/token ({gap_pct:+.1f}%)")

    # The branches below cover every gap value (including exactly -5%).
    print("\n   Interpretation:")
    if abs(gap_pct) < 5:
        print("     ✓ Excellent generalization (gap < 5%)")
        if train_cpt >= 3.5:
            print("       Both train and test have good compression")
        elif train_cpt >= 2.5:
            print("       ⚠ Consider larger vocab or more training data")
        else:
            print("       ✗ Insufficient vocab size - need larger vocab")
    elif gap_pct < 0:
        # gap <= -5%: held-out data compresses noticeably better than training data
        print("     ⚠ Unusual: Test better than train (may indicate data issues)")
    elif gap_pct < 15:
        print("     ⚠ Slight overfitting (5-15% gap)")
        print("       Consider more diverse training data")
    else:
        print("     ✗ Overfitting detected (gap ≥ 15%)")
        print("       Vocab may be too large or training data not diverse enough")
    return test_cpt


def _report_overall_quality(train_cpt: float, test_cpt: float) -> None:
    """Grade the average of train and test compression (train alone if no test set)."""
    print("\n   Overall Quality:")
    avg_compression = _average_compression(train_cpt, test_cpt)
    if avg_compression >= 3.5:
        print(f"     ✓ Good compression ({avg_compression:.2f} chars/token)")
    elif avg_compression >= 2.5:
        print(f"     ○ Fair compression ({avg_compression:.2f} chars/token)")
    else:
        print(f"     ✗ Poor compression ({avg_compression:.2f} chars/token)")


def _report_outliers(train_stats, errors, train_cpt: float, num_files: int) -> None:
    """Print the per-file compression spread and flag a heavy low-compression tail.

    A file is an outlier when its chars/token is below Q1 - 1.5 * IQR.
    """
    import statistics

    print("\n   Outlier Analysis:")
    print("   Examining individual file compression to detect heavy tail...")

    # (chars/token, name, size), sorted ascending => hardest to encode first.
    file_compressions = sorted(
        (n_chars / n_tokens, name, n_chars)
        for name, n_chars, n_tokens in train_stats
        if n_chars > 0 and n_tokens > 0
    )
    if not file_compressions:
        print("\n     ⚠ Could not analyze files")
        print(f"       Processed files: 0/{num_files}")
        if errors:
            print(f"       Errors encountered: {len(errors)}")
            print("       First few errors:")
            for path, error in errors[:3]:
                print(f"         {path}: {error}")
        return

    compressions = [c for c, _, _ in file_compressions]
    print(f"\n     Statistics across {len(compressions)} training files:")
    print(f"       Min:    {compressions[0]:.2f} chars/token (hardest)")
    print(f"       Median: {statistics.median(compressions):.2f} chars/token")
    # Token-weighted mean: total chars / total tokens over the sample.
    print(f"       Mean:   {train_cpt:.2f} chars/token")
    print(f"       Max:    {compressions[-1]:.2f} chars/token (easiest)")

    print("\n     5 Hardest to Encode (potential outliers):")
    for i, (comp, name, size) in enumerate(file_compressions[:5], 1):
        print(f"       {i}. {comp:.2f} chars/token - {name} ({size:,} chars)")

    print("\n     5 Easiest to Encode:")
    for i, (comp, name, size) in enumerate(file_compressions[-5:], 1):
        print(f"       {i}. {comp:.2f} chars/token - {name} ({size:,} chars)")

    if len(compressions) < 2:
        # statistics.quantiles() raises StatisticsError for fewer than two
        # data points on Python < 3.13.
        print("\n     Too few files for outlier detection")
        return

    # Check for heavy tail
    q1, _, q3 = statistics.quantiles(compressions, n=4)  # 25th / 50th / 75th percentiles
    outlier_threshold = q1 - 1.5 * (q3 - q1)
    outliers = [c for c in compressions if c < outlier_threshold]
    outlier_pct = len(outliers) / len(compressions) * 100

    if len(outliers) > len(compressions) * 0.1:  # >10% outliers
        print(f"\n     ⚠ Heavy tail detected: {len(outliers)} outliers ({outlier_pct:.1f}%)")
        print("       These hard-to-encode files are pulling down the average")
    elif outliers:
        print(f"\n     Minor tail: {len(outliers)} statistical outliers ({outlier_pct:.1f}%)")
    else:
        print("\n     ✓ No significant outliers detected")


def _report_keywords(tokenizer) -> float:
    """Print keywords that need more than one token; return the % encoded as one token."""
    print("\n2. Common Python Keywords (should be 1 token each):")
    # NOTE(review): keywords are encoded without a leading space. With
    # ByteLevel(add_prefix_space=False), in-code occurrences are often preceded
    # by a space and may map to a different token; confirm this is the intended measure.
    single_count = 0
    for keyword in _PYTHON_KEYWORDS:
        encoding = tokenizer.encode(keyword)
        if len(encoding.ids) == 1:
            single_count += 1
        else:
            print(f"   ✗ {keyword:10} → {len(encoding.ids)} tokens {encoding.tokens}")

    total = len(_PYTHON_KEYWORDS)
    single_token_pct = single_count / total * 100
    print(f"\n   Single-token keywords: {single_token_pct:.0f}% ({single_count}/{total})")
    return single_token_pct


def _report_patterns(tokenizer) -> None:
    """Print the token count and chars/token for each common code snippet."""
    print("\n3. Common Code Patterns:")
    for pattern in _CODE_PATTERNS:
        n_tokens = len(tokenizer.encode(pattern).ids)
        print(f"   '{pattern}'")
        print(f"      → {n_tokens} tokens ({len(pattern) / n_tokens:.2f} chars/token)")


def _report_example(tokenizer) -> None:
    """Encode a small repo-formatted sample that uses special tokens, and print the tokens."""
    print("\n4. Example Encoding:")
    example_text = """<reponame>user/project<filename>main.py
def hello():
    print("Hello, world!")
<|endoftext|>"""

    encoding = tokenizer.encode(example_text)
    print(f"Input ({len(example_text)} chars):\n{example_text}")
    print(f"\nEncoded to {len(encoding.ids)} tokens ({len(example_text)/len(encoding.ids):.2f} chars/token):")
    print(f"Tokens: {encoding.tokens}")


def _report_vocab(tokenizer) -> None:
    """Print total, special, and remaining (non-special) vocabulary counts."""
    vocab_size = tokenizer.get_vocab_size()
    print("\n5. Vocabulary Statistics:")
    print(f"   Total vocab size: {vocab_size:,}")
    print(f"   Special tokens: {len(SPECIAL_TOKENS)}")
    print(f"   Learned tokens: {vocab_size - len(SPECIAL_TOKENS):,}")


def _report_summary(tokenizer, train_cpt: float, test_cpt: float, single_token_pct: float) -> None:
    """Print an overall verdict from compression, keyword efficiency, and train/test gap."""
    _print_header("Summary:")

    avg_compression = _average_compression(train_cpt, test_cpt)
    gap_pct = _gap_pct(train_cpt, test_cpt)

    if avg_compression >= 3.5 and single_token_pct >= 80 and abs(gap_pct) < 10:
        print("✓ Excellent tokenizer quality!")
        print("  - Good compression on both train and test")
        print("  - Keywords encoded efficiently")
        print("  - Minimal overfitting")
    elif avg_compression >= 2.5 and single_token_pct >= 60:
        print("○ Fair tokenizer quality")
        if abs(gap_pct) >= 15:
            print("  ⚠ Overfitting detected - need more diverse training data")
        elif avg_compression < 3.0:
            print("  ⚠ Low compression - consider larger vocab or more training data")
        else:
            print("  ⚠ Consider more training data for better quality")
    else:
        print("✗ Poor tokenizer quality")
        if avg_compression < 2.5:
            print("  - Insufficient vocab size OR need much more training data")
        if single_token_pct < 60:
            print("  - Common keywords not learned efficiently")

    print("\nMetrics:")
    print(f"  Train compression: {train_cpt:.2f} chars/token")
    if test_cpt > 0:
        print(f"  Test compression:  {test_cpt:.2f} chars/token")
        print(f"  Train-test gap:    {gap_pct:+.1f}%")
    print(f"  Keyword efficiency: {single_token_pct:.0f}%")
    print(f"  Vocab size: {tokenizer.get_vocab_size():,}")


def train_tokenizer():
    """Train, save, and evaluate a byte-level BPE tokenizer.

    Trains on the files from ``get_training_files()`` and saves the result to
    ``OUTPUT_PATH``. It then prints a report covering:

    - compression on up to ``_MAX_EVAL_FILES`` training files, and on as many
      held-out files from ``test_files/`` when that directory exists;
    - per-file outliers;
    - Python keyword and code-pattern efficiency;
    - an example encoding;
    - a summary verdict.
    """
    print(f"Training tokenizer with vocab size {VOCAB_SIZE:,}")

    tokenizer = _build_tokenizer()
    trainer = trainers.BpeTrainer(
        vocab_size=VOCAB_SIZE,
        special_tokens=SPECIAL_TOKENS,
        show_progress=True,
        # Start from the full byte-level alphabet so any input can be encoded.
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )

    training_files = get_training_files()
    print(f"Training on {len(training_files)} files from {INPUT_DIR}")
    tokenizer.train(files=training_files, trainer=trainer)

    # No post-processor on purpose: <|endoftext|> is added manually when
    # concatenating repos, so individual files encode without an automatic EOS.

    tokenizer.save(str(OUTPUT_PATH))
    print(f"\n✓ Tokenizer saved to: {OUTPUT_PATH}")

    _report_special_tokens(tokenizer)

    _print_header("Encoding Efficiency Analysis:")
    print("\n1. Compression Ratio (characters per token):")

    # Each sampled training file is read and encoded once; the result feeds
    # both the aggregate ratio and the outlier analysis.
    print("\n   Training Set:")
    train_sample = training_files[:_MAX_EVAL_FILES]
    train_stats, train_errors = _encode_files(tokenizer, train_sample)
    train_cpt = _chars_per_token(train_stats)
    print(f"     Files: {len(train_sample)}")
    print(f"     Compression: {train_cpt:.2f} chars/token")

    test_cpt = _report_test_set(tokenizer, train_cpt)
    _report_overall_quality(train_cpt, test_cpt)
    _report_outliers(train_stats, train_errors, train_cpt, len(train_sample))

    single_token_pct = _report_keywords(tokenizer)
    _report_patterns(tokenizer)
    _report_example(tokenizer)
    _report_vocab(tokenizer)
    _report_summary(tokenizer, train_cpt, test_cpt, single_token_pct)


if __name__ == "__main__":
    # Script entry point: train, save, then print the evaluation report.
    train_tokenizer()
