import torch
from torch.utils.data import DataLoader

from settings import train_func


def retrain(
    model, forget_set, retain_set, config,
    trainer_init_func=None,
    trainer_init_kwargs=None,
    device=None,
    unl_logs=None,
):
    """Train ``model`` on ``retain_set`` only and return the resulting model.

    The path is selected by ``config.llama``:

    * truthy: set ``model.config.use_cache = False``, build a trainer with
      ``trainer_init_func(**kwargs)``, and call ``trainer.train()``. Here
      ``kwargs`` is the attributes of ``trainer_init_kwargs`` plus
      ``model=model`` and ``train_dataset=retain_set``. Returns
      ``trainer.model``.
    * falsy: build a shuffled ``DataLoader`` over ``retain_set``, plus a loss,
      optimizer and ``MultiplicativeLR`` schedule, all from ``config``. Run
      ``train_func.train`` with them and return ``model``.

    Args:
        model: Model to train. On the llama path it must expose
            ``model.config.use_cache``.
        forget_set: Not read by this function. Presumably accepted so the
            signature matches sibling unlearning methods (TODO confirm).
        retain_set: Dataset the model is trained on.
        config: Settings object. Always reads ``llama``. On the non-llama path
            it also reads ``train_batch_size``, ``loss`` (a ``torch.nn`` class
            name), ``optimizer`` (a ``torch.optim`` class name),
            ``learning_rate``, ``weight_decay``, ``lr_update_interval``,
            ``num_epochs`` and ``log_frequency``.
        trainer_init_func: Trainer factory; required when ``config.llama``.
        trainer_init_kwargs: Object accepted by ``vars()`` (e.g. a namespace)
            holding extra trainer arguments; required when ``config.llama``.
            It is not modified.
        device: Forwarded to ``train_func.train`` on the non-llama path.
        unl_logs: Not read by this function.

    Returns:
        The retrained model.

    Raises:
        ValueError: If ``config.llama`` is set but ``trainer_init_func`` or
            ``trainer_init_kwargs`` is ``None``.
    """
    if config.llama:
        # Validate before touching the model so a bad call has no side effects.
        if trainer_init_func is None or trainer_init_kwargs is None:
            raise ValueError(
                "retrain: trainer_init_func and trainer_init_kwargs are required "
                "when config.llama is set"
            )
        model.config.use_cache = False
        # Build a fresh kwargs dict rather than assigning model/train_dataset
        # onto the caller's object; explicit keys override any same-named ones.
        init_kwargs = {
            **vars(trainer_init_kwargs),
            "model": model,
            "train_dataset": retain_set,
        }
        trainer = trainer_init_func(**init_kwargs)
        trainer.train()
        # Return what the trainer holds; it may not be the same object as
        # `model` (e.g. if the trainer wraps it — TODO confirm).
        return trainer.model

    retain_loader = DataLoader(retain_set, shuffle=True, batch_size=config.train_batch_size)
    criterion = getattr(torch.nn, config.loss)()
    optimizer_cls = getattr(torch.optim, config.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    interval = config.lr_update_interval
    # MultiplicativeLR multiplies the LR by the returned factor on every
    # scheduler step, passing its own step counter. This halves the LR each
    # time the counter reaches a multiple of `interval`. Whether a "step" is
    # a batch or an epoch depends on how train_func.train calls
    # lr_scheduler.step() — not visible here.
    lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
        optimizer,
        lambda n_steps: 0.5 if n_steps % interval == 0 else 1.0,
    )
    train_func.train(
        model, retain_loader, criterion, optimizer,
        eval_loader=None,
        num_epochs=config.num_epochs,
        log_frequency=config.log_frequency,
        lr_scheduler=lr_scheduler,
        device=device,
    )
    # NOTE(review): assumes train_func.train updates `model` in place, since
    # its return value was never used — confirm.
    return model