#!/usr/bin/env python3
"""
Step 1b: Save missing files (vectorized, fast)

Converts teacher tokens to distilled format and saves:
- {split}_distilled_input_ids.npy (with MASK tokens)
- {split}_distilled_labels.npy
- {split}_attention_mask.npy
"""

import sys
sys.path.append('..')

import json
import argparse
import numpy as np
from pathlib import Path

from config import get_model_paths, RESULTS_ROOT
from distill_config import DISTILL_CONFIG


def build_token_mapping(teacher_vocab, distilled_vocab):
    """Build a lookup table converting teacher token ids to distilled ids.

    Each teacher token is collapsed onto a distilled token by the first
    character of its name: names starting with ``N``, ``A``, ``T``, ``G`` or
    ``C`` map to the distilled token of that letter. ``<PAD>`` and every
    other token (anything not starting with one of those letters) map to the
    distilled ``<PAD>`` id.

    Args:
        teacher_vocab: Mapping of teacher token name -> integer token id.
        distilled_vocab: Mapping of distilled token name -> integer token id.
            Must contain ``<PAD>`` and every base letter that occurs as the
            first character of a teacher token.

    Returns:
        A 1-D ``np.int32`` array ``mapping`` such that ``mapping[teacher_id]``
        is the distilled id. The table has ``max(teacher id) + 1`` entries so
        that non-contiguous teacher ids remain addressable; ids that no
        teacher token uses resolve to ``<PAD>``.

    Raises:
        KeyError: If ``distilled_vocab`` lacks ``<PAD>`` or a required base.
        ValueError: If a teacher token id is negative.
    """
    # A frozenset (not the string 'NATGC') so that an empty prefix is not
    # treated as a member.
    bases = frozenset('NATGC')
    pad_id = distilled_vocab['<PAD>']

    # Size by the largest id rather than by the number of tokens: a vocab
    # with gaps in its ids would otherwise index past the end of the table.
    table_size = max(teacher_vocab.values(), default=-1) + 1
    mapping = np.full(table_size, pad_id, dtype=np.int32)

    for token_name, tid in teacher_vocab.items():
        if tid < 0:
            # A negative id would silently wrap around to the end of the table.
            raise ValueError(
                f"Negative token id {tid} for teacher token {token_name!r}"
            )
        prefix = token_name[:1]
        # Always assign (even for PAD) so that, if two names share an id,
        # the last one in the vocab decides the mapping.
        mapping[tid] = distilled_vocab[prefix] if prefix in bases else pad_id

    return mapping


def process_split(split, teacher_paths, token_mapping, distilled_vocab, output_dir):
    """Convert one teacher split to the distilled format and save it.

    Loads ``{split}_input_ids.npy``, ``{split}_labels.npy`` and
    ``{split}_attention_mask.npy`` from ``teacher_paths['processed_dir']``,
    translates token ids through ``token_mapping`` and writes
    ``{split}_distilled_input_ids.npy``, ``{split}_distilled_labels.npy`` and
    ``{split}_attention_mask.npy`` into ``output_dir``.

    Positions whose teacher label is not ``-100`` are replaced by the
    distilled ``<MASK>`` id in the input ids and receive the mapped label;
    all other positions keep their mapped input id and a label of ``-100``.
    The attention mask is saved unchanged.

    Args:
        split: Split name used as the file-name prefix (e.g. ``'train'``).
        teacher_paths: Mapping that provides the ``'processed_dir'`` entry.
        token_mapping: Integer array indexed by teacher token id that yields
            the distilled token id (see ``build_token_mapping``).
        distilled_vocab: Mapping of distilled token name -> id; must contain
            ``'<MASK>'``.
        output_dir: Directory (``str`` or ``Path``) to write the files to.

    Raises:
        ValueError: If the input ids are not 2-D or the labels do not have
            the same shape as the input ids.
    """
    print(f"\n{'='*60}")
    print(f"Processing {split}")
    print(f"{'='*60}")

    processed_dir = Path(teacher_paths['processed_dir'])
    output_dir = Path(output_dir)
    mask_id = distilled_vocab['<MASK>']

    # Load teacher data
    print(f"Loading teacher {split} data...")
    teacher_input_ids = np.load(processed_dir / f'{split}_input_ids.npy')
    teacher_labels = np.load(processed_dir / f'{split}_labels.npy')
    teacher_attention_mask = np.load(processed_dir / f'{split}_attention_mask.npy')

    num_samples, seq_len = teacher_input_ids.shape
    print(f"Samples: {num_samples:,}, Seq length: {seq_len}")

    if teacher_labels.shape != teacher_input_ids.shape:
        raise ValueError(
            f"{split}: labels shape {teacher_labels.shape} does not match "
            f"input_ids shape {teacher_input_ids.shape}"
        )

    # Vectorized conversion
    print("Converting tokens (vectorized)...")

    # Fancy indexing returns a fresh array, so the in-place masking below
    # does not touch the lookup table or the loaded teacher data.
    distilled_input_ids = token_mapping[teacher_input_ids]

    # Set MASK token where labels != -100
    mask_positions = teacher_labels != -100
    distilled_input_ids[mask_positions] = mask_id

    # Map only the real label positions. Passing the whole label array through
    # the table would also look up index -100, which either wraps around
    # silently or raises IndexError when the table has fewer than 100 entries.
    distilled_labels = np.full(teacher_labels.shape, -100, dtype=np.int32)
    distilled_labels[mask_positions] = token_mapping[teacher_labels[mask_positions]]

    # Save
    print("Saving...")
    np.save(output_dir / f'{split}_distilled_input_ids.npy', distilled_input_ids)
    np.save(output_dir / f'{split}_distilled_labels.npy', distilled_labels)
    np.save(output_dir / f'{split}_attention_mask.npy', teacher_attention_mask)

    print("Saved:")
    print(f"  {split}_distilled_input_ids.npy")
    print(f"  {split}_distilled_labels.npy")
    print(f"  {split}_attention_mask.npy")

    # Verify
    print("Verification:")
    if distilled_input_ids.size == 0:
        # min()/max() are undefined on empty arrays and the ratio would
        # divide by zero.
        print("  (empty split, no statistics)")
        return
    print(f"  input_ids range: [{distilled_input_ids.min()}, {distilled_input_ids.max()}]")
    mask_count = int((distilled_input_ids == mask_id).sum())
    # Report the real MASK id instead of a hard-coded number.
    print(f"  MASK token ({mask_id}) count: {mask_count:,}")
    mask_ratio = mask_count / distilled_input_ids.size * 100
    print(f"  MASK ratio: {mask_ratio:.1f}%")


def main():
    """Command-line entry point.

    Parses ``--split`` (``train``, ``val`` or ``both``) and ``--output_dir``,
    creates the output directory, loads the teacher and distilled vocab JSON
    files named by the ``'vocab_file'`` entry of each model's paths, builds
    the token lookup table once, and runs ``process_split`` for every
    requested split.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--split', default='both', choices=['train', 'val', 'both'])
    cli.add_argument(
        '--output_dir',
        default=str(RESULTS_ROOT / '6_distillation'),
        help='Output directory',
    )
    opts = cli.parse_args()

    out_dir = Path(opts.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"Output directory: {out_dir}")

    def read_vocab(vocab_path):
        # Vocab files are JSON objects; return the parsed content as-is.
        with open(vocab_path) as handle:
            return json.load(handle)

    # Resolve both models' paths before touching any vocab file.
    teacher_paths = get_model_paths(DISTILL_CONFIG['teacher_model'])
    distilled_paths = get_model_paths(DISTILL_CONFIG['distilled_base'])

    teacher_vocab = read_vocab(teacher_paths['vocab_file'])
    distilled_vocab = read_vocab(distilled_paths['vocab_file'])

    print(f"Teacher vocab size: {len(teacher_vocab)}")
    print(f"Distilled vocab size: {len(distilled_vocab)}")

    # The lookup table is shared by every split, so build it a single time.
    print("Building token mapping...")
    lookup = build_token_mapping(teacher_vocab, distilled_vocab)

    # 'both' expands to train followed by val; otherwise run the one split.
    selected = ('train', 'val') if opts.split == 'both' else (opts.split,)
    for split_name in selected:
        process_split(split_name, teacher_paths, lookup, distilled_vocab, out_dir)

    print("\nDone!")


# Run the conversion only when this file is executed as a script, so that
# importing the module does not parse arguments or write any files.
if __name__ == '__main__':
    main()
