from trl import SFTTrainer, SFTConfig
import json
import os
from transformers import TrainerCallback

class MetricsCallback(TrainerCallback):
    """Custom callback to save training metrics to a JSON file.

    Every payload the Trainer passes to ``on_log`` is tagged with the current
    global step and epoch and appended to an in-memory history. The history is
    periodically flushed to ``<output_dir>/training_metrics.json`` during
    training and written once more when training ends. Only the world
    process-zero records and writes, so distributed ranks do not race on the
    same file.
    """

    def __init__(self, output_dir, save_every_n_steps=10):
        """
        Args:
            output_dir: Directory that receives ``training_metrics.json``;
                created on first write if it does not exist.
            save_every_n_steps: Interval, in global steps, for intermediate
                flushes of the JSON file (default 10). A non-positive value
                disables intermediate flushes; the file is then written only
                at the end of training.
        """
        self.output_dir = output_dir
        self.save_every_n_steps = save_every_n_steps
        self.metrics_history = []
        # Global step at which the last intermediate flush happened.
        self._last_saved_step = 0

    def _save(self):
        """Write the full metrics history to disk and return the file path."""
        metrics_file = os.path.join(self.output_dir, "training_metrics.json")
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
        tmp_file = metrics_file + ".tmp"
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(self.metrics_history, f, indent=2)
        # Write-then-rename so a crash mid-dump never leaves a truncated
        # training_metrics.json behind (rename is atomic on POSIX).
        os.replace(tmp_file, metrics_file)
        return metrics_file

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """Record the logged metrics and periodically flush them to JSON."""
        # The Trainer invokes callbacks on every process; only the main
        # process keeps history and writes the file.
        if logs is None or not state.is_world_process_zero:
            return

        step = state.global_step
        self.metrics_history.append({
            "step": step,
            "epoch": state.epoch,
            **logs,
        })

        interval = self.save_every_n_steps
        if interval <= 0:
            return
        # Flush on exact multiples of the interval (the original behaviour),
        # and also whenever a full interval has elapsed since the last flush.
        # The second condition covers logging_steps values that rarely or
        # never land on an exact multiple.
        if step % interval == 0 or step - self._last_saved_step >= interval:
            self._save()
            self._last_saved_step = step

    def on_train_end(self, args, state, control, **kwargs):
        """Save final metrics at the end of training (main process only)."""
        if not state.is_world_process_zero:
            return
        metrics_file = self._save()
        print(f"Training metrics saved to: {metrics_file}")

def create_training_args(config):
    """Translate the ``config['training']`` section into an ``SFTConfig``.

    Required keys in ``config['training']``: output_dir, num_train_epochs,
    per_device_train_batch_size, per_device_eval_batch_size,
    gradient_accumulation_steps, optim, adam_beta2, adam_epsilon,
    max_grad_norm, lr_scheduler_type, learning_rate, save_steps,
    logging_steps, group_by_length, fp16, bf16, warmup_steps, weight_decay.
    Optional keys: max_length (default 2048) and packing (default False).

    TensorBoard reporting is always enabled, with logs written under
    ``<output_dir>/logs``, and the first step is always logged. Evaluation and
    checkpointing share one cadence: ``"steps"`` when ``save_steps`` is
    positive, ``"epoch"`` otherwise. ``eval_steps`` is set to
    ``logging_steps``.

    Raises:
        KeyError: if ``config`` lacks ``'training'`` or any required key.
    """
    train_cfg = config['training']

    def cadence():
        # Shared by eval_strategy and save_strategy.
        # NOTE(review): the eval cadence is keyed off save_steps while
        # eval_steps comes from logging_steps -- confirm that pairing is
        # intended. Step/epoch evaluation presumably needs an eval_dataset.
        return "steps" if train_cfg['save_steps'] > 0 else "epoch"

    return SFTConfig(
        # Run layout, epochs and batch sizing
        output_dir=train_cfg['output_dir'],
        num_train_epochs=train_cfg['num_train_epochs'],
        per_device_train_batch_size=train_cfg['per_device_train_batch_size'],
        per_device_eval_batch_size=train_cfg['per_device_eval_batch_size'],
        gradient_accumulation_steps=train_cfg['gradient_accumulation_steps'],

        # Optimiser and learning-rate schedule
        optim=train_cfg['optim'],
        adam_beta2=train_cfg['adam_beta2'],
        adam_epsilon=train_cfg['adam_epsilon'],
        max_grad_norm=train_cfg['max_grad_norm'],
        lr_scheduler_type=train_cfg['lr_scheduler_type'],
        learning_rate=train_cfg['learning_rate'],

        # Checkpoint and logging intervals
        save_steps=train_cfg['save_steps'],
        logging_steps=train_cfg['logging_steps'],

        # Batching, precision and regularisation
        group_by_length=train_cfg['group_by_length'],
        fp16=train_cfg['fp16'],
        bf16=train_cfg['bf16'],
        warmup_steps=train_cfg['warmup_steps'],
        weight_decay=train_cfg['weight_decay'],

        # SFT-only knobs, with fallbacks when absent from the config
        max_length=train_cfg.get('max_length', 2048),
        packing=train_cfg.get('packing', False),

        # Reporting, evaluation and save cadence
        report_to=["tensorboard"],
        logging_dir=f"{train_cfg['output_dir']}/logs",
        logging_first_step=True,
        eval_strategy=cadence(),
        eval_steps=train_cfg['logging_steps'],
        save_strategy=cadence(),
    )

def train_model(model, tokenizer, train_dataset, eval_dataset, config):
    """
    Train model using SFTTrainer with enhanced logging.
    Saves metrics to both TensorBoard and JSON file.

    Args:
        model: Model to fine-tune, passed unchanged to ``SFTTrainer``.
        tokenizer: Tokenizer, passed as ``processing_class`` and saved next
            to the trained model.
        train_dataset: Training dataset, passed through to ``SFTTrainer``.
        eval_dataset: Evaluation dataset, passed through to ``SFTTrainer``.
        config: Nested dict. ``config['training']`` feeds
            ``create_training_args`` and the JSON metrics callback;
            ``config['model']['new_model']`` is the directory that receives
            the final model and tokenizer.
    """
    training_args = create_training_args(config)

    # Mirrors every logged metric into <output_dir>/training_metrics.json.
    metrics_callback = MetricsCallback(config['training']['output_dir'])

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,  # tokenizer is also the processing class
        callbacks=[metrics_callback],
    )

    trainer.train()

    new_model_dir = config['model']['new_model']
    # Trainer.save_model writes only from the main process and gathers
    # sharded weights (DeepSpeed ZeRO-3 / FSDP) before saving. A raw
    # model.save_pretrained() does neither. save_model also saves the
    # trainer's processing_class (the tokenizer) into the same directory.
    trainer.save_model(new_model_dir)

    if trainer.is_world_process_zero():
        print(f"Model saved to: {new_model_dir}")
        # Report the directory actually configured on the training args.
        print(f"TensorBoard logs saved to: {training_args.logging_dir}")