import torch
from torch.utils.data import DataLoader

from settings import train_func

def gd(
    model, forget_set, retain_set, config,
    trainer_init_func=None,
    trainer_init_kwargs=None,
    device=None,
    unl_logs=None,
):
    """Continue training ``model`` on ``retain_set`` and return the trained model.

    Two code paths exist, selected by ``config.llama``:

    * truthy: a trainer object is built with ``trainer_init_func`` and its
      ``train()`` method is run; the trainer's ``model`` attribute is returned.
    * falsy: the model is trained through ``train_func.train`` using SGD, a
      single full-size batch per epoch and a step-wise halving LR schedule.

    Args:
        model: Model to train. In the non-llama path its parameters are
            handed to ``torch.optim.SGD``.
        forget_set: Unused by this function; kept so the signature stays
            compatible with callers.
        retain_set: Dataset used for training. Must support ``len()`` in the
            non-llama path because its length is used as the batch size.
        config: Object exposing ``llama`` and, for the non-llama path,
            ``learning_rate``, ``weight_decay``, ``loss`` (name of a class in
            ``torch.nn``), ``lr_update_interval``, ``num_epochs`` and
            ``log_frequency``.
        trainer_init_func: Callable that builds the trainer (llama path only).
        trainer_init_kwargs: Namespace-like object whose attributes are passed
            as keyword arguments to ``trainer_init_func`` (llama path only).
            It is read but not modified.
        device: Forwarded unchanged to ``train_func.train``.
        unl_logs: Unused by this function; kept for signature compatibility.

    Returns:
        ``trainer.model`` in the llama path, otherwise the ``model`` argument
        that was passed to ``train_func.train``.

    Raises:
        ValueError: If ``config.llama`` is truthy but ``trainer_init_func`` or
            ``trainer_init_kwargs`` was not supplied.
    """
    if config.llama:
        if trainer_init_func is None or trainer_init_kwargs is None:
            raise ValueError(
                "trainer_init_func and trainer_init_kwargs are required "
                "when config.llama is set"
            )
        # Work on a copy so the caller's kwargs object does not end up holding
        # a reference to the model / dataset after this call returns.
        init_kwargs = dict(vars(trainer_init_kwargs))
        init_kwargs["model"] = model
        init_kwargs["train_dataset"] = retain_set
        trainer = trainer_init_func(**init_kwargs)
        trainer.train()
        # Hand back the trainer's model explicitly instead of relying on
        # `model` having been updated in place.
        return trainer.model

    gd_optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    # batch_size == dataset size: every epoch is one full-batch update.
    retain_loader = DataLoader(
        retain_set, shuffle=True, batch_size=len(retain_set)
    )
    # config.loss names a loss class in torch.nn; it is built with defaults.
    criterion = getattr(torch.nn, config.loss)()
    # Multiply the LR by 0.5 whenever the scheduler's step counter is a
    # multiple of config.lr_update_interval; leave it unchanged otherwise.
    lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
        gd_optimizer,
        lambda step: 0.5 if step % config.lr_update_interval == 0 else 1.0,
    )

    train_func.train(
        model, retain_loader, criterion, gd_optimizer,
        eval_loader=None,
        num_epochs=config.num_epochs,
        log_frequency=config.log_frequency,
        lr_scheduler=lr_scheduler,
        device=device,
    )
    return model