#!/usr/bin/env python3
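"""
Example usage of the enhanced train_sft.py script.

This script prints example train_sft.py command lines for several
configurations, followed by a usage guide. It does not launch any training
itself. Pass --help to print only the usage guide.
"""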

import subprocess
import sys
import os

def run_training_example():
    """Print example ``train_sft.py`` command lines for three configurations.

    Nothing is executed: each example is printed as a title line, a
    ``Command:`` line and a blank line. Every argument is shell-quoted so the
    printed command can be pasted into a terminal as-is, after replacing the
    placeholder dataset paths.

    Returns:
        list[list[str]]: The argv list of each example, in print order, so a
        caller can run one (e.g. via ``subprocess.run``) if desired.
    """
    import shlex  # local import keeps this block self-contained

    examples = [
        (
            # Example 1: Basic SFT training with dataset splitting and testing
            "=== Example 1: Basic SFT Training with Testing ===",
            [
                "python", "train_sft.py",
                "--dataset_path", "/path/to/your/dataset.json",
                "--split_dataset",
                "--test_size", "200",
                "--task_type", "readmission",
                "--sft_epochs", "2",
                "--sft_lr", "5e-5",
                "--lora_r", "32",
                "--batch_size", "2",
                "--gradient_accumulation_steps", "4",
                "--test_model",
            ],
        ),
        (
            # Example 2: SFT + DPO training with custom hyperparameters
            "=== Example 2: SFT + DPO Training ===",
            [
                "python", "train_sft.py",
                "--dataset_path", "/path/to/your/sft_dataset.json",
                "--dpo_dataset_path", "/path/to/your/dpo_dataset.json",
                "--split_dataset",
                "--test_size", "150",
                "--task_type", "mortality",
                "--enable_dpo",
                "--sft_epochs", "1",
                "--dpo_epochs", "1",
                "--sft_lr", "3e-5",
                "--dpo_lr", "3e-5",
                "--lora_r", "16",
                "--lora_dropout", "0.1",
                "--batch_size", "1",
                "--gradient_accumulation_steps", "8",
                "--test_model",
            ],
        ),
        (
            # Example 3: High-performance training with larger LoRA
            "=== Example 3: High-Performance Training ===",
            [
                "python", "train_sft.py",
                "--dataset_path", "/path/to/your/dataset.json",
                "--split_dataset",
                "--test_size", "100",
                "--task_type", "period",
                "--sft_epochs", "3",
                "--sft_lr", "2e-5",
                "--lora_r", "64",
                "--lora_dropout", "0.05",
                "--batch_size", "4",
                "--gradient_accumulation_steps", "2",
                "--test_model",
            ],
        ),
    ]

    for title, cmd in examples:
        print(title)
        # shlex.quote leaves plain tokens untouched but quotes arguments with
        # spaces/metacharacters; a bare " ".join would print a broken command
        # for such paths.
        print("Command:", " ".join(shlex.quote(arg) for arg in cmd))
        print()

    return [cmd for _, cmd in examples]

def print_help():
    """Write the train_sft.py usage guide (flags, defaults, output layout) to stdout."""
    usage_guide = """
Enhanced QWEN Training Script Usage Guide
========================================

The enhanced train_sft.py script now supports:

1. Dataset Selection and Splitting:
   --dataset_path: Path to your training dataset JSON file
   --split_dataset: Automatically split dataset into train/test sets
   --test_size: Number of samples for testing (default: 200)

2. Hyperparameter Configuration:
   --lora_r: LoRA rank (default: 16)
   --lora_dropout: LoRA dropout rate (default: 0.05)
   --sft_lr: SFT learning rate (default: 4e-5)
   --sft_epochs: Number of SFT epochs (default: 1)
   --batch_size: Per device batch size (default: 1)
   --gradient_accumulation_steps: Gradient accumulation steps (default: 8)

3. DPO Training (Optional):
   --enable_dpo: Enable DPO training after SFT
   --dpo_dataset_path: Path to DPO dataset
   --dpo_lr: DPO learning rate (default: 4e-5)
   --dpo_epochs: Number of DPO epochs (default: 1)

4. Model Testing:
   --test_model: Test the trained model on test set
   --task_type: Type of task (readmission, mortality, period)

Key Features:
- Automatic dataset splitting (first N samples become test set)
- Integrated model testing using the same logic as qwen_inference.py
- Flexible hyperparameter configuration
- Support for both SFT-only and SFT+DPO training
- Automatic result saving and metrics calculation

Output Structure:
- Trained models: ./trained_models/{dataset_name}_Qwen3-0.6B-SFT_r{rank}_lr{lr}_ep{epochs}_bs{batch}_gas{gas}/
- Test data: ./test_data/{dataset_name}_test.json
- Test results: ./test_data/{dataset_name}_test_results_r{rank}_lr{lr}_ep{epochs}_bs{batch}_gas{gas}.json

Example: clinical_data_Qwen3-0.6B-SFT_r16_lr4e-5_ep2_bs1_gas8/

Example Usage:
python train_sft.py --dataset_path data.json --split_dataset --test_model --task_type readmission
"""
    # The literal is plain text, not an f-string: the {placeholders} are
    # meant to be shown verbatim.
    print(usage_guide)

if __name__ == "__main__":
    # Help-only mode: print just the usage guide. Accept both the short and
    # the long conventional help flag (previously only "--help" worked, so
    # "-h" fell through to printing everything).
    if len(sys.argv) > 1 and sys.argv[1] in ("-h", "--help"):
        print_help()
    else:
        # Default mode: show the example commands, then the usage guide.
        run_training_example()
        print_help()
