from helper import utils
from helper.thirdparty.tofu.dataloader import custom_data_collator_forget, CustomTrainerForgetting

def idk(model, tokenizer, train_data, config, training_args):
    """Run one unlearning pass using the ``'idk'`` forget loss.

    Builds a ``CustomTrainerForgetting`` over ``train_data`` and trains
    ``model`` with it in place. Nothing is returned.

    Args:
        model: The model to be unlearned; passed straight to the trainer.
        tokenizer: Tokenizer handed to the trainer.
        train_data: Dataset the trainer iterates over; batched with
            ``custom_data_collator_forget``.
        config: Accepted for signature compatibility with sibling
            unlearning entry points (presumably) but not read here.
        training_args: Training-arguments object forwarded as ``args``.
    """
    trainer_kwargs = {
        "model": model,
        "tokenizer": tokenizer,
        "train_dataset": train_data,
        "args": training_args,
        "data_collator": custom_data_collator_forget,
        # No metric function is given to the trainer.
        "compute_metrics": None,
        # No reference model is supplied; presumably the 'idk' loss does not
        # need one (NOTE(review): confirm against CustomTrainerForgetting).
        "oracle_model": None,
        "forget_loss": "idk",
        # Per the original author: disables evaluation during unlearning.
        "eval_cfg": None,
    }

    CustomTrainerForgetting(**trainer_kwargs).train()
