import torch
from tqdm import tqdm
from torch.utils.data import DataLoader

from helper.thirdparty.tofu.dataloader import custom_data_collator_forget, CustomTrainerForgetting


def gdiff_tofu(model, tokenizer, train_data, config, training_args):
    """Run gradient-difference unlearning through ``CustomTrainerForgetting``.

    Args:
        model: Model handed to the trainer.
        tokenizer: Tokenizer handed to the trainer.
        train_data: Dataset handed to the trainer as ``train_dataset``;
            batches are built with ``custom_data_collator_forget``.
        config: Accepted for signature compatibility; not read here.
        training_args: Trainer arguments forwarded as ``args``.

    The trainer is configured with ``forget_loss='grad_diff'``, no oracle
    model, no ``compute_metrics``, and ``eval_cfg=None``. The last setting
    turns off evaluation during unlearning. The function then calls
    ``train()`` and returns ``None``.
    """
    trainer_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_data,
        compute_metrics=None,
        args=training_args,
        data_collator=custom_data_collator_forget,
        oracle_model=None,
        forget_loss='grad_diff',
        eval_cfg=None,  # no evaluation while unlearning
    )
    CustomTrainerForgetting(**trainer_kwargs).train()

def _next_cycled(batch_iter, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` once ``batch_iter`` is exhausted.

    The (possibly fresh) iterator is returned so the caller can keep drawing
    from it on subsequent calls.
    """
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def gdiff(
    model, forget_set, retain_set, config,
    trainer_init_func=None,
    trainer_init_kwargs=None,
    device=None,
    unl_logs=None,
):
    """Unlearn ``forget_set`` from ``model`` with Gradient Difference.

    Each step draws one forget batch and one retain batch. It computes the
    negated loss on the forget batch (gradient ascent) plus the ordinary
    loss on the retain batch (gradient descent), and takes a single
    optimizer step on their sum. ``model`` is updated in place.

    Args:
        model: ``torch.nn.Module`` to train; put into ``train()`` mode.
        forget_set: Dataset whose batches unpack as ``(inputs, labels)``.
        retain_set: Dataset whose batches unpack as ``(inputs, labels)``.
        config: Object providing ``train_batch_size``, ``loss`` (name of a
            ``torch.nn`` loss class), ``optimizer`` (name of a
            ``torch.optim`` class), ``learning_rate``, ``weight_decay`` and
            ``num_epochs``.
        trainer_init_func, trainer_init_kwargs, unl_logs: Accepted but not
            used here (presumably kept for a signature shared with other
            unlearning methods; confirm with callers).
        device: Device that batches are moved to. When ``None``, the device
            of the model's first parameter is used, so inputs follow the
            model. If the model has no parameters, tensors are left where
            they are.

    Returns:
        None.
    """
    forget_loader = DataLoader(forget_set, shuffle=True, batch_size=config.train_batch_size)
    retain_loader = DataLoader(retain_set, shuffle=True, batch_size=config.train_batch_size)
    criterion = getattr(torch.nn, config.loss)()
    optimizer_cls = getattr(torch.optim, config.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

    # Without an explicit device, `.to(None)` would leave CPU batches on the
    # CPU even when the model sits on an accelerator, causing a device-mismatch
    # error in the forward pass. Follow the model's own device instead.
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device

    model.train()
    forget_iter = iter(forget_loader)
    retain_iter = iter(retain_loader)
    # An "epoch" is as many paired steps as the shorter loader yields. Iterators
    # are not reset per epoch: the longer loader continues where it stopped and
    # is only restarted when it actually runs out.
    steps_per_epoch = min(len(forget_loader), len(retain_loader))
    for _epoch in tqdm(range(config.num_epochs), "GDiff"):
        for _step in range(steps_per_epoch):
            forget_batch, forget_iter = _next_cycled(forget_iter, forget_loader)
            retain_batch, retain_iter = _next_cycled(retain_iter, retain_loader)

            optimizer.zero_grad()
            forget_inputs, forget_labels = forget_batch
            forget_outputs = model(forget_inputs.to(device))
            # Negated so that minimizing the total loss ascends on the forget set.
            forget_loss = (-1) * criterion(forget_outputs, forget_labels.to(device))

            retain_inputs, retain_labels = retain_batch
            retain_outputs = model(retain_inputs.to(device))
            retain_loss = criterion(retain_outputs, retain_labels.to(device))

            gdiff_loss = forget_loss + retain_loss
            gdiff_loss.backward()
            optimizer.step()
