import copy
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm


def delete(
    model, forget_set, retain_set, config,
    trainer_init_func=None,
    trainer_init_kwargs=None,
    device=None,
    unl_logs=None,
):
    """Unlearn ``forget_set`` by distilling from a masked copy of ``model``.

    A frozen deep copy of ``model`` (the "teacher") scores every forget batch.
    The logit of each sample's own target class is masked out before the
    softmax, so the resulting soft labels put zero probability on the class
    that should be forgotten. ``model`` is then trained to minimise the
    KL-divergence between its own predictions and those soft labels.

    Args:
        model: ``nn.Module`` to unlearn; it is optimised in place.
        forget_set: dataset whose items unpack as ``(inputs, targets)``.
            ``targets`` are used as integer class indices along dim 1 of the
            model output.
        retain_set: not used by this function.
        config: object providing ``optimizer`` (name of a class in
            ``torch.optim``), ``learning_rate``, ``weight_decay``,
            ``train_batch_size`` and ``num_epochs``.
        trainer_init_func: not used by this function.
        trainer_init_kwargs: not used by this function.
        device: device on which to run. When given, ``model``, the teacher
            copy and every batch are moved there; when ``None`` nothing is
            moved.
        unl_logs: not used by this function.

    Returns:
        The same ``model`` object, after unlearning.

    Note:
        The train/eval mode of ``model`` is left exactly as the caller set it.
    """
    # Keep the trainable model on the same device as the teacher and the
    # batches; otherwise the forward pass fails with a device mismatch.
    if device is not None:
        model.to(device)

    # freeze the original model -- it will be used to generate new soft labels
    orig_model = copy.deepcopy(model)
    orig_model.to(device)
    orig_model.eval()

    # objective is minimizing KL-divergence to the newly generated labels
    criterion = nn.KLDivLoss(reduction='batchmean')
    optimizer_cls = getattr(torch.optim, config.optimizer)
    optimizer = optimizer_cls(
        model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay,
    )

    # load the forget set
    forget_loader = DataLoader(forget_set, shuffle=True, batch_size=config.train_batch_size)

    # start unlearning
    for _ in tqdm(range(config.num_epochs), desc="Running DELETE"):
        for inputs, targets in forget_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # The soft labels are constants for the optimisation step, so the
            # whole target construction runs without autograd tracking.
            with torch.no_grad():
                orig_logits = orig_model(inputs)

                # Mask the target-class logit with the most negative finite
                # value of the logits' own dtype. Unlike a fixed -1e10 this is
                # representable in reduced-precision dtypes too, and it still
                # yields a probability of 0 after the softmax.
                mask_value = torch.finfo(orig_logits.dtype).min
                row_idx = torch.arange(inputs.shape[0], device=orig_logits.device)
                orig_logits[row_idx, targets] = mask_value
                new_probs = F.softmax(orig_logits, dim=1)

            # KLDivLoss expects log-probabilities as input and probabilities
            # as target.
            curr_log_probs = F.log_softmax(model(inputs), dim=1)

            # The optimizer holds exactly model.parameters(), so clearing its
            # gradients clears every gradient of the model.
            optimizer.zero_grad()
            unl_loss = criterion(curr_log_probs, new_probs)
            unl_loss.backward()
            optimizer.step()

            print("Unlearning Loss: {:.4f}".format(unl_loss.item()))

    return model