from helper import utils
from helper.thirdparty.tofu.dataloader import custom_data_collator_forget, CustomTrainerForgetting

def scrub_tofu(model, oracle_model, tokenizer, train_data, config, training_args):
    """Drive alternating 'scrub_maximize' / 'scrub_minimize' training passes.

    A single ``CustomTrainerForgetting`` is built with ``forget_loss='scrub'``
    and reused for every pass. For each of ``config.num_total_epochs`` rounds
    the trainer's ``loss_type`` is set to ``'scrub_maximize'`` and
    ``trainer.train()`` is called, then the same is done with
    ``'scrub_minimize'``; ``utils.clear_cache()`` runs after every call to
    ``train()``.

    Args:
        model: Model handed to the trainer as ``model``.
        oracle_model: Handed to the trainer as ``oracle_model``.
        tokenizer: Handed to the trainer as ``tokenizer``.
        train_data: Handed to the trainer as ``train_dataset``.
        config: Object exposing ``num_total_epochs``, the number of
            maximize/minimize rounds to run.
        training_args: Handed to the trainer as ``args``. Its
            ``learning_rate`` attribute is overwritten with ``1e-4`` in place,
            so the caller's object is modified.

    Returns:
        None. Presumably ``model`` is updated in place by the trainer --
        TODO confirm against ``CustomTrainerForgetting``.
    """
    # The learning rate is pinned here regardless of what the caller supplied.
    training_args.learning_rate = 1e-4

    trainer = CustomTrainerForgetting(
        model=model,
        oracle_model=oracle_model,
        tokenizer=tokenizer,
        train_dataset=train_data,
        data_collator=custom_data_collator_forget,
        args=training_args,
        forget_loss='scrub',
        # No metrics function is passed to the trainer; per the original
        # note, metrics are computed in a separate callback.
        compute_metrics=None,
        # No evaluation config: evaluation is switched off during unlearning.
        eval_cfg=None,
    )

    # Each round runs one maximize pass followed by one minimize pass.
    phases = ('scrub_maximize', 'scrub_minimize')
    for _ in range(config.num_total_epochs):
        for phase in phases:
            trainer.loss_type = phase
            trainer.train()
            utils.clear_cache()
