from helper import utils
from helper.thirdparty.tofu.dataloader import custom_data_collator_forget, CustomTrainerForgetting

def tofu_baseline(model, oracle_model, tokenizer, train_data, config, training_args, unlearn_method):
    """Run a TOFU baseline unlearning pass and return the trainer used.

    Builds a ``CustomTrainerForgetting`` over ``train_data`` and trains it
    with the forget loss named by ``unlearn_method``.

    Args:
        model: The model to be unlearned. It is handed to the trainer as
            ``model``.
        oracle_model: Reference model passed through as ``oracle_model``.
            Presumably it is only needed by forget losses that compare
            against a frozen copy. TODO confirm against
            ``CustomTrainerForgetting``.
        tokenizer: Tokenizer forwarded to the trainer.
        train_data: Dataset used as ``train_dataset``. Batches are built by
            ``custom_data_collator_forget``.
        config: Currently unused. It is kept so the signature stays
            compatible with existing callers.
        training_args: Training arguments forwarded as ``args``.
        unlearn_method: Forget-loss identifier forwarded as ``forget_loss``.

    Returns:
        The trainer instance after ``train()`` has completed. It is
        returned so callers can inspect or save training state. Callers
        that ignore the return value behave exactly as before.
    """
    trainer = CustomTrainerForgetting(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_data,
        # No metric callback is attached here. Per the original author,
        # metrics are computed in a separate callback.
        compute_metrics=None,
        args=training_args,
        data_collator=custom_data_collator_forget,
        oracle_model=oracle_model,
        forget_loss=unlearn_method,
        # Turn off evaluation during unlearning.
        eval_cfg=None,
    )

    trainer.train()
    # Previously the trainer was discarded here. Returning it exposes the
    # trained state without changing behaviour for existing callers.
    return trainer
    