from .retrain import retrain
from .impl import load_unlearn_checkpoint, save_unlearn_checkpoint
from .boundary_ex import boundary_expanding
from .boundary_sh import boundary_shrink


from .Con import CON

def raw(data_loaders, model, criterion, args, mask=None):
    """No-op placeholder: accept the arguments and do nothing.

    Every argument (including the optional ``mask``) is ignored, nothing
    is modified, and the function always returns ``None``.
    """
    return None


def get_unlearn_method(name):
    """Look up an unlearning routine by its string identifier.

    Recognised names are ``"retrain"``, ``"boundary_expanding"``,
    ``"boundary_shrink"`` and ``"CON"``; each maps to the object of the
    same name imported at the top of this module.

    method usage:

        function(data_loaders, model, criterion, args)

    Raises:
        NotImplementedError: if ``name`` equals none of the names above.
    """
    # Guard-clause dispatch: the first matching name returns immediately.
    if name == "retrain":
        return retrain
    if name == "boundary_expanding":
        return boundary_expanding
    if name == "boundary_shrink":
        return boundary_shrink
    if name == "CON":
        return CON

    # Nothing matched, so report the unsupported identifier to the caller.
    raise NotImplementedError(f"Unlearn method {name} not implemented!")
