from trl import DPOConfig, SFTConfig, SFTTrainer
from utils.dpo_trainer import DPOTrainer

def get_trainer_config(trainer_name: str, model_dir, args):
    """Return the trainer class and its training config for ``trainer_name``.

    Both configs share the same core settings: epochs, batch size, learning
    rate, ``max_steps``, a ``"linear"`` LR scheduler with ``warmup_ratio``,
    gradient accumulation, and seed. Both also use ``logging_steps=10``,
    ``save_strategy="no"``, ``resume_from_checkpoint=False``,
    ``remove_unused_columns=False`` and ``bf16=True``.

    Args:
        trainer_name: ``"sft"`` selects ``SFTTrainer`` with ``SFTConfig``;
            ``"dpo"`` selects the project ``DPOTrainer`` with ``DPOConfig``.
        model_dir: Value passed to the config as ``output_dir``.
        args: Object exposing ``epochs``, ``batch_size``, ``lr``,
            ``max_steps``, ``warmup_ratio``, ``gradient_accumulation_steps``
            and ``seed``. The SFT path also reads ``max_seq_len``.

    Returns:
        A ``(trainer_class, config)`` tuple. The trainer class is returned
        uninstantiated, so the caller constructs it.

    Raises:
        ValueError: If ``trainer_name`` is neither ``"sft"`` nor ``"dpo"``.
    """
    # Validate before touching ``args`` so an unknown name always surfaces as
    # ValueError (the original if/elif/else never read ``args`` in that case).
    if trainer_name not in ("sft", "dpo"):
        raise ValueError(f"Invalid trainer: {trainer_name}")

    # Keyword arguments shared by both configs. They are defined once so the
    # SFT and DPO runs cannot silently drift apart when one branch is edited.
    common_kwargs = {
        "output_dir": model_dir,
        "num_train_epochs": args.epochs,
        "per_device_train_batch_size": args.batch_size,
        "learning_rate": args.lr,
        "logging_steps": 10,
        "max_steps": args.max_steps,
        "save_strategy": "no",
        "lr_scheduler_type": "linear",
        "warmup_ratio": args.warmup_ratio,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "resume_from_checkpoint": False,
        "remove_unused_columns": False,
        "seed": args.seed,
        "bf16": True,
    }

    if trainer_name == "sft":
        # NOTE(review): newer TRL releases rename SFTConfig.max_seq_length to
        # max_length. Confirm against the pinned TRL version.
        config = SFTConfig(max_seq_length=args.max_seq_len, **common_kwargs)
        return SFTTrainer, config

    # NOTE(review): unlike the SFT path, args.max_seq_len is not forwarded
    # here, so DPOConfig's own length defaults apply. Confirm this is
    # intentional.
    return DPOTrainer, DPOConfig(**common_kwargs)