from .FT import FT_iter
import pruner
from .impl import iterative_unlearn

# Number of epochs between successive pruning rounds in FT_prune_bi.
prune_step = 2


@iterative_unlearn
def FT_prune_bi(data_loaders, model, criterion, optimizer, epoch, args):
    """Run one fine-tuning epoch, pruning the model every ``prune_step`` epochs.

    The model is put in train mode. When
    ``(args.unlearn_epochs - epoch) % prune_step == 0``, a pruning round is
    applied: random pruning if ``args.random_prune`` is truthy, otherwise
    ``pruner.pruning_model`` (reported as "L1 pruning"). Sparsity is reported
    via ``pruner.check_sparsity`` on every call. The epoch's training is then
    delegated to ``FT_iter``.

    The per-round rate is chosen so that ``(1 - rate_per_round) ** rounds ==
    1 - args.rate``. This assumes each pruner call removes that fraction of
    the weights that are still unpruned (TODO confirm in ``pruner``).

    NOTE(review): ``rounds`` is ``(args.unlearn_epochs - 1) // prune_step + 1``.
    That count equals the number of epochs passing the check above only when
    ``epoch`` runs over ``1..unlearn_epochs``. If ``iterative_unlearn`` passes
    a 0-based epoch, then for an odd ``unlearn_epochs`` one fewer round
    fires, and the final sparsity falls short of ``args.rate``. Verify
    against ``iterative_unlearn``.

    Args:
        data_loaders: Passed through unchanged to ``FT_iter``.
        model: Model to prune and fine-tune.
        criterion: Loss function, passed through to ``FT_iter``.
        optimizer: Optimizer, passed through to ``FT_iter``.
        epoch: Current epoch index, supplied by the ``iterative_unlearn``
            wrapper (presumably its epoch loop).
        args: Namespace; this function reads ``unlearn_epochs``, ``rate``
            and ``random_prune``. It is forwarded to ``FT_iter``.

    Returns:
        Whatever ``FT_iter`` returns for this epoch.
    """
    model.train()

    # Split the target sparsity across all rounds so that it compounds to args.rate.
    rounds = (args.unlearn_epochs - 1) // prune_step + 1
    rate_per_round = 1 - (1 - args.rate) ** (1 / rounds)

    epochs_remaining = args.unlearn_epochs - epoch
    if epochs_remaining % prune_step == 0:
        if args.random_prune:
            print('random pruning')
            prune_fn = pruner.pruning_model_random
        else:
            print('L1 pruning')
            prune_fn = pruner.pruning_model
        prune_fn(model, rate_per_round)

    pruner.check_sparsity(model)

    return FT_iter(data_loaders, model, criterion, optimizer, epoch, args)
