import torch

from helper import utils
from settings import train_func

def sgd(
    model, forget_set, retain_set, config,
    trainer_init_func=None,
    trainer_init_kwargs=None,
    device=None,
    unl_logs=None,
):
    """Train ``model`` on ``retain_set`` with plain SGD.

    Only ``model``, ``retain_set``, ``config`` and ``device`` are read here;
    ``forget_set``, ``trainer_init_func``, ``trainer_init_kwargs`` and
    ``unl_logs`` are accepted but not used by this function.

    Fields read from ``config``: ``llama``, ``learning_rate``,
    ``weight_decay``, ``train_batch_size``, ``loss`` (the name of a class in
    ``torch.nn``, instantiated with no arguments), ``lr_update_interval``,
    ``num_epochs`` and ``log_frequency``.

    Nothing is returned explicitly; the training itself is delegated to
    ``train_func.train`` (presumably updating ``model`` in place -- confirm
    against that helper).

    Raises:
        NotImplementedError: if ``config.llama`` is truthy.
    """
    # Guard clause: the llama setting is not handled by this function.
    if config.llama:
        raise NotImplementedError("SGD implementation coincides with GD :)")

    def _lr_factor(step):
        # Multiplicative factor for the learning rate: halve it whenever the
        # scheduler's step counter is a multiple of the configured interval.
        # NOTE(review): whether a "step" is an epoch or a batch depends on
        # how train_func.train drives the scheduler -- confirm there.
        if step % config.lr_update_interval == 0:
            return 0.5
        return 1.0

    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    loader = utils.get_dataloader(
        retain_set,
        shuffle=True,
        batch_size=config.train_batch_size,
    )

    # Resolve the loss class by name, then build it with default arguments.
    loss_cls = getattr(torch.nn, config.loss)
    loss_fn = loss_cls()

    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, _lr_factor)

    train_func.train(
        model, loader, loss_fn, optimizer,
        eval_loader=None,
        num_epochs=config.num_epochs,
        log_frequency=config.log_frequency,
        lr_scheduler=scheduler,
        device=device,
    )