#!/usr/bin/env python3
"""
Configuration for Knowledge Distillation (Embedding-based)
"""

import sys
sys.path.append('..')

from pathlib import Path
from config import RESULTS_ROOT, get_model_paths, get_model_config

# =============================================================================
# Distillation Paths
# =============================================================================
# Base path for distillation outputs: the '6_distillation' entry under
# RESULTS_ROOT (imported from the project-level config module).
DISTILL_ROOT = RESULTS_ROOT / '6_distillation'

def get_distill_paths():
    """Return the paths used by the distillation stage.

    Returns:
        A dict, built fresh on every call, with these keys:
            'root'            -- DISTILL_ROOT itself
            'teacher_hidden'  -- DISTILL_ROOT / 'teacher_hidden'
            'distilled_model' -- DISTILL_ROOT / 'distilled' / 'model'
            'results'         -- DISTILL_ROOT / 'results'
    """
    base = DISTILL_ROOT
    layout = {'root': base}
    layout['teacher_hidden'] = base / 'teacher_hidden'
    layout['distilled_model'] = base / 'distilled' / 'model'
    layout['results'] = base / 'results'
    return layout

# =============================================================================
# Teacher Config (M3)
# =============================================================================
TEACHER_CONFIG = {
    # Identifier of the teacher model.
    'model_name': 'full',

    # Model dimensions (~86M parameters at this size).
    'd_model': 768,
    'nhead': 12,
    'num_layers': 12,

    # Maximum sequence length.
    'max_seq_len': 1002,
}

# =============================================================================
# Distilled Config (Smaller)
# =============================================================================
DISTILLED_CONFIG = {
    # Width is kept the same as the teacher for MSE compatibility.
    'd_model': 768,

    # Attention heads reduced: 12 -> 4.
    'nhead': 4,

    # Layers reduced: 12 -> 4.
    'num_layers': 4,

    # Maximum sequence length.
    'max_seq_len': 1002,

    # ~28M parameters (about 1/3 the size).
}

# =============================================================================
# Distillation Hyperparameters
# =============================================================================
DISTILL_CONFIG = {
    # --- Model selection ---
    'teacher_model': 'full',
    # Base model name; used for the vocabulary.
    'distilled_base': 'seq',

    # --- Training ---
    'batch_size': 128,
    # Slightly higher learning rate for the smaller model.
    'learning_rate': 5e-5,
    'epochs': 10,
    'warmup_steps': 1000,
    'gradient_clip': 1.0,

    # --- Distillation ---
    # Weight for the CE loss; the MSE term is weighted by (1 - alpha).
    'alpha': 0.5,
}
