#!/usr/bin/env python3
"""
Test script to demonstrate unified training for both baseline and Rosetta modes.
The training mode is automatically detected from the config file.
"""

import json
import os
import tempfile
import subprocess

def create_baseline_config():
    """Build the sample configuration dict for a baseline training run.

    The ``model`` section carries only a ``baseline_model`` entry, which is
    the key this script's summary output associates with baseline mode.

    Returns:
        A freshly constructed, JSON-serializable dict with the top-level
        sections ``model``, ``training``, ``output`` and ``data`` (in that
        order).
    """
    model_section = {"baseline_model": "Qwen/Qwen3-0.6B"}

    # Optional: LoRA settings, nested under the training section.
    lora_section = {
        "r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.1,
    }

    training_section = {
        "learning_rate": 1e-5,
        "weight_decay": 0.01,
        "num_epochs": 1,
        "max_length": 2048,
        "device": "cuda",
        "scheduler_type": "linear",
        "warmup_ratio": 0.1,
        "max_grad_norm": 1.0,
        "per_device_train_batch_size": 4,
        "freeze": [],
        "seed": 42,
        "lora": lora_section,
    }

    wandb_section = {
        "project": "unified_training",
        "mode": "offline",
        "run_name": "baseline_test",
    }
    output_section = {
        "output_dir": "outputs/baseline_test",
        "save_steps": 100,
        "eval_steps": 50,
        "wandb_config": wandb_section,
    }

    dataset_kwargs = {
        "split": "test",
        "num_samples": 100,
        "max_word_count": 256,
    }
    data_section = {
        "type": "MMLUChatDataset",
        "kwargs": dataset_kwargs,
        "train_ratio": 0.9,
    }

    # Sections are assembled in a fixed order so the serialized JSON layout
    # stays stable.
    return {
        "model": model_section,
        "training": training_section,
        "output": output_section,
        "data": data_section,
    }

def create_rosetta_config():
    """Build the sample configuration dict for a Rosetta training run.

    The ``model`` section carries ``base_model`` and ``teacher_model``
    entries (the keys this script's summary output associates with Rosetta
    mode) together with projector, mapping and aggregator settings.

    Returns:
        A freshly constructed, JSON-serializable dict with the top-level
        sections ``model``, ``training``, ``output`` and ``data`` (in that
        order).
    """
    projector_section = {
        "type": "AdditiveProjector",
        "params": {
            "hidden_dim": 1024,
            "num_layers": 3,
            "dropout": 0.1,
            "activation": "gelu",
            "use_layer_norm": True,
            "init_weight": 0.0,
            "anneal_steps": 100,
        },
    }
    aggregator_section = {
        "type": "WeightedAggregator",
        "params": {
            "num_options": 3,
            "initial_temperature": 1.0,
            "final_temperature": 0.0001,
            "anneal_steps": 100,
        },
    }
    model_section = {
        "base_model": "Qwen/Qwen3-0.6B",
        "teacher_model": "Qwen/Qwen3-4B",
        "include_response": False,
        "projector": projector_section,
        "mapping": "last_aligned",
        "aggregator": aggregator_section,
    }

    training_section = {
        "learning_rate": 3e-4,
        "weight_decay": 0.01,
        "num_epochs": 1,
        "max_length": 2048,
        "device": "cuda",
        "scheduler_type": "linear",
        "warmup_ratio": 0.1,
        "max_grad_norm": 1.0,
        "per_device_train_batch_size": 2,
        "freeze": ["teacher", "base"],
        "seed": 42,
    }

    wandb_section = {
        "project": "unified_training",
        "mode": "offline",
        "run_name": "rosetta_test",
    }
    output_section = {
        "output_dir": "outputs/rosetta_test",
        "save_steps": 100,
        "eval_steps": 50,
        "wandb_config": wandb_section,
    }

    dataset_kwargs = {
        "split": "test",
        "num_samples": 100,
        "max_word_count": 256,
    }
    data_section = {
        "type": "MMLUChatDataset",
        "kwargs": dataset_kwargs,
        "train_ratio": 0.9,
    }

    # Sections are assembled in a fixed order so the serialized JSON layout
    # stays stable.
    return {
        "model": model_section,
        "training": training_section,
        "output": output_section,
        "data": data_section,
    }

def _write_temp_config(config):
    """Dump *config* as JSON into a new temporary ``.json`` file.

    The file is created with ``delete=False`` so it is still on disk after
    it is closed; the caller prints its path for later use and nothing in
    this script removes it.

    Args:
        config: A JSON-serializable configuration dict.

    Returns:
        The filesystem path of the written file.
    """
    # Explicit UTF-8 keeps the file encoding independent of the locale.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False, encoding="utf-8"
    ) as handle:
        json.dump(config, handle)
        return handle.name


def test_mode_detection():
    """Write both sample configs to temp files and print how to train with them.

    This function does not launch training or check mode detection itself:
    it materializes the baseline and Rosetta sample configs as temporary
    JSON files and prints the commands a user can run against each one.

    Returns:
        A ``(baseline_config_path, rosetta_config_path)`` tuple of the
        temporary file paths that were written.
    """
    print("Testing mode detection...")

    # Baseline config first, then Rosetta, matching the printed order below.
    baseline_config_path = _write_temp_config(create_baseline_config())
    rosetta_config_path = _write_temp_config(create_rosetta_config())

    print(f"Created baseline config: {baseline_config_path}")
    print(f"Created Rosetta config: {rosetta_config_path}")

    # Commands to run
    print("\nTo test baseline training, run:")
    print(f"python script/train/SFT_train.py --config {baseline_config_path}")

    print("\nTo test Rosetta training, run:")
    print(f"python script/train/SFT_train.py --config {rosetta_config_path}")

    return baseline_config_path, rosetta_config_path

def main():
    """Print the demo banner, generate the sample configs, and summarize both modes.

    Everything is written to stdout. The sample config files are produced
    by ``test_mode_detection``; its returned paths are not needed here
    because that function already prints them.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    for banner_line in (heavy_rule, "Unified Training Script Test", heavy_rule):
        print(banner_line)

    # Create test configs
    test_mode_detection()

    # Each entry is one printed line; a leading "\n" yields the blank
    # separator line before a section.
    summary_lines = (
        "\n" + heavy_rule,
        "Key differences between modes:",
        light_rule,
        "\nBaseline Mode (when 'baseline_model' is in config):",
        "  - Trains a single model",
        "  - Supports LoRA and partial parameter training",
        "  - Uses BaselineChatDataset and BaselineDataCollator",
        "  - Simpler forward pass without KV cache management",
        "\nRosetta Mode (when 'base_model' and 'teacher_model' are in config):",
        "  - Trains with knowledge distillation from teacher to student",
        "  - Uses projectors and aggregators",
        "  - Supports token alignment between different tokenizers",
        "  - Uses RosettaDataCollator with KV cache indexing",
        "\n" + heavy_rule,
        "The training mode is automatically detected from the config.",
        "No code changes needed - just use different config files!",
        heavy_rule,
    )
    for summary_line in summary_lines:
        print(summary_line)

# Run the demo only when this file is executed as a script, so importing it
# (e.g. to reuse the config builders) has no side effects.
if __name__ == "__main__":
    main()
