from src.unlearn.finetune import train_classification_model


def train_unlearn(
    model,
    train_loader,
    device,
    lr,
    alpha: float = 0.1,
    num_epochs: int = 10000,
    val_loader=None,
    loss_fn=None,
    lr_scheduler: None | str = None,
    patience: int = 5,
    weight_decay: float = 1e-6,
    save_path: None | str = None,
):
    """Train ``model`` for unlearning via ``train_classification_model`` with L1 regularization on.

    This is a thin wrapper. It always passes ``use_l1_reg=True`` and forwards
    every other argument unchanged, by keyword, to ``train_classification_model``.

    Args:
        model: Model to train; forwarded as-is.
        train_loader: Training data source; forwarded as-is.
        device: Device argument; forwarded as-is.
        lr: Learning rate; forwarded as-is.
        alpha: Forwarded as ``alpha``. Presumably the L1 regularization
            weight, since ``use_l1_reg=True`` is set alongside it; confirm
            in ``train_classification_model``.
        num_epochs: Forwarded as ``num_epochs``.
        val_loader: Optional validation data; forwarded as-is.
        loss_fn: Optional loss function; ``None`` is forwarded, leaving the
            choice of default to ``train_classification_model``.
        lr_scheduler: Optional scheduler name; forwarded as-is.
        patience: Forwarded as ``patience``. Presumably early-stopping
            patience; confirm in ``train_classification_model``.
        weight_decay: Forwarded as ``weight_decay``.
        save_path: Optional path; forwarded as-is.

    Returns:
        Whatever ``train_classification_model`` returns, propagated unchanged
        so callers can use it (may be ``None``).
    """
    # Unlearning here means classification training with L1 regularization
    # forced on; everything else is configurable by the caller.
    return train_classification_model(
        model=model,
        train_loader=train_loader,
        loss_fn=loss_fn,
        device=device,
        lr=lr,
        num_epochs=num_epochs,
        use_l1_reg=True,
        alpha=alpha,
        val_loader=val_loader,
        lr_scheduler=lr_scheduler,
        patience=patience,
        weight_decay=weight_decay,
        save_path=save_path,
    )
