import torch
from functools import partial
from torch.utils.data import DataLoader

from settings import train_func, train_transformers_func
from helper import utils

def ga(
    model, forget_set, retain_set, config,
    trainer_init_func=None,
    trainer_init_kwargs=None,
    device=None,
    unl_logs=None,
    ):
    """Train `model` on `forget_set` with a negated loss (gradient ascent).

    Two code paths are selected by `config.llama`:

    * truthy: a trainer object is built with `trainer_init_func`, its
      `compute_loss` is replaced by a version that negates the loss, and
      `trainer.train()` is run.
    * falsy: a plain PyTorch loop is delegated to `train_func.train` using a
      negated `torch.nn` criterion.

    Args:
        model: The model to update.
        forget_set: Dataset whose loss is maximised.
        retain_set: Not used by this function (presumably kept so the
            signature matches other unlearning methods -- TODO confirm).
        config: Object exposing `llama`; the non-llama path additionally
            reads `train_batch_size`, `loss`, `optimizer`, `learning_rate`,
            `weight_decay`, `lr_update_interval`, `num_epochs` and
            `log_frequency`.
        trainer_init_func: Callable building the trainer from keyword
            arguments. Required when `config.llama` is truthy.
        trainer_init_kwargs: Extra keyword arguments for `trainer_init_func`,
            given as an attribute namespace (anything `vars()` accepts), a
            dict, or None. It is never modified; `model` and `train_dataset`
            are added to a private copy.
        device: Forwarded to `train_func.train` on the non-llama path.
        unl_logs: Not used by this function.

    Returns:
        The trained model: `trainer.model` on the llama path, otherwise the
        `model` argument that was passed in.

    Raises:
        TypeError: If `config.llama` is truthy and `trainer_init_func` is None.
    """
    if config.llama:
        if trainer_init_func is None:
            raise TypeError(
                "ga() requires `trainer_init_func` when `config.llama` is set"
            )

        # Build the trainer arguments on a private copy: the caller's object
        # must not be mutated, and None (the default) has to be accepted.
        if trainer_init_kwargs is None:
            init_kwargs = {}
        elif isinstance(trainer_init_kwargs, dict):
            init_kwargs = dict(trainer_init_kwargs)
        else:
            init_kwargs = dict(vars(trainer_init_kwargs))
        init_kwargs["model"] = model
        init_kwargs["train_dataset"] = forget_set
        trainer = trainer_init_func(**init_kwargs)

        # Take the un-negated implementation from the trainer's class, not
        # from the instance: the instance attribute is overridden just below
        # and going through it would recurse forever.
        base_compute_loss = type(trainer).compute_loss

        def negated_compute_loss(*args, **kwargs):
            """Call the class-level `compute_loss` and flip the loss sign."""
            result = base_compute_loss(trainer, *args, **kwargs)
            if isinstance(result, tuple):
                # A (loss, outputs) pair: only the loss is negated.
                loss, outputs = result
                return (-loss, outputs)
            return -result

        trainer.compute_loss = negated_compute_loss
        trainer.train()
        return trainer.model

    forget_loader = DataLoader(
        forget_set, shuffle=True, batch_size=config.train_batch_size,
    )
    criterion = getattr(torch.nn, config.loss)()

    def negate_loss_fn(*args, **kwargs):
        """Return the negated criterion value, so minimising it ascends."""
        return -criterion(*args, **kwargs)

    optimizer_cls = getattr(torch.optim, config.optimizer)
    optimizer = optimizer_cls(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )

    def lr_factor(step):
        """Halve the LR when `step` is a multiple of the update interval."""
        return 0.5 if step % config.lr_update_interval == 0 else 1.0

    lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_factor)

    train_func.train(
        model, forget_loader, negate_loss_fn, optimizer,
        eval_loader=None, num_epochs=config.num_epochs,
        log_frequency=config.log_frequency, lr_scheduler=lr_scheduler,
        device=device,
    )
    return model
        