# Code from https://raw.githubusercontent.com/OPTML-Group/Unlearn-Saliency/refs/heads/master/Classification/unlearn/Wfisher.py
import torch
from torch.autograd import grad
from tqdm import tqdm


def get_require_grad_params(model: torch.nn.Module, named=False):
    """Collect the parameters of ``model`` that have ``requires_grad`` set.

    Args:
        model: Module whose parameters are inspected.
        named: When truthy, each entry is a ``(name, parameter)`` tuple taken
            from ``model.named_parameters()``; otherwise each entry is the
            bare parameter taken from ``model.parameters()``.

    Returns:
        A list, in the module's own iteration order, containing only the
        parameters for which ``requires_grad`` is true.
    """
    if not named:
        return [weight for weight in model.parameters() if weight.requires_grad]

    return [
        (key, weight)
        for key, weight in model.named_parameters()
        if weight.requires_grad
    ]


def get_sample_grad(model, loss):
    """Return the gradient of ``loss`` as a single flat vector.

    The gradient is taken with ``torch.autograd.grad`` with respect to every
    parameter of ``model`` that requires grad, in the order produced by
    ``get_require_grad_params(model, named=False)``. The per-parameter
    gradients are flattened and concatenated in that same order.

    Args:
        model: Module whose trainable parameters are differentiated against.
        loss: Scalar tensor whose autograd graph reaches those parameters.

    Returns:
        A 1-D tensor holding all per-parameter gradients back to back.
    """
    # The helper already returns a fresh list, so no extra copy is needed.
    params = get_require_grad_params(model, named=False)
    per_param_grads = grad(loss, params)

    # ``reshape`` instead of ``view``: autograd may hand back a non-contiguous
    # gradient (e.g. one that flowed back through a transpose), on which
    # ``view(-1)`` raises. For contiguous gradients the result is identical.
    return torch.cat([g.reshape(-1) for g in per_param_grads])


def apply_perturb(model, v, mask=None):
    """Add the flat vector ``v`` onto the trainable parameters of ``model`` in place.

    ``v`` is consumed front to back: each parameter that requires grad takes
    the next ``numel`` entries, in the order returned by
    ``get_require_grad_params``. The update is written through ``.data``, so
    autograd does not record it.

    Args:
        model: Module whose trainable parameters are modified in place.
        v: 1-D tensor laid out like the concatenation of the flattened
            trainable parameters.
        mask: Optional mapping from parameter name to a tensor. When truthy,
            each slice of ``v`` is multiplied element-wise by the flattened
            mask entry for that parameter before being added. A falsy value
            (``None`` or an empty mapping) applies ``v`` unmasked.
            NOTE(review): mask tensors presumably have the same number of
            elements as their parameter — confirm with the caller.

    Returns:
        None.
    """
    offset = 0

    if not mask:
        for weight in get_require_grad_params(model, named=False):
            flat = weight.view(-1)
            end = offset + flat.shape[0]
            flat.data += v[offset:end].data
            offset = end
        return

    for key, weight in get_require_grad_params(model, named=True):
        flat = weight.view(-1)
        end = offset + flat.shape[0]
        flat.data += v[offset:end].data * mask[key].view(-1)
        offset = end


def fisher(model, train_dl, device, criterion, v):
    """Iteratively correct a copy of ``v`` using per-batch gradients from ``train_dl``.

    NOTE(review): this looks like a Woodbury-style recursive estimate of an
    inverse-Fisher-vector product (the source file is named ``Wfisher.py``);
    confirm against the referenced repository before relying on that reading.

    For every ``(data, label)`` batch the flat gradient ``g`` of
    ``criterion(model(data).logits, label)`` is computed. The first batch only
    seeds the running vector ``o`` with ``g``. Every later batch updates, in
    this order::

        t = <o, g>
        k <- k - (<k, g> / (N + t)) * o
        o <- o - (t / (N + t)) * o

    with the constant ``N = 1000``.

    Args:
        model: Module to differentiate; it is switched to eval mode. Its
            forward output must expose a ``.logits`` attribute.
        train_dl: Iterable yielding ``(data, label)`` pairs whose items
            support ``.to(device)``.
        device: Device that each batch is moved to.
        criterion: Callable ``criterion(logits, label)`` returning the loss.
        v: Flat vector to start from; it is cloned, not modified.

    Returns:
        The updated copy of ``v``. Iteration stops when ``train_dl`` is
        exhausted or right after the batch whose index first exceeds ``N``.
    """
    # The same constant is the additive term in both denominators and the
    # batch-index cutoff checked at the end of each iteration.
    N = 1000

    model.eval()
    k_vec = v.clone()
    o_vec = None

    for idx, (data, label) in enumerate(tqdm(train_dl)):
        model.zero_grad()
        inputs = data.to(device)
        targets = label.to(device)
        logits = model(inputs).logits

        sample_grad = get_sample_grad(model, criterion(logits, targets))

        with torch.no_grad():
            if o_vec is not None:
                tmp = torch.dot(o_vec, sample_grad)
                denom = N + tmp
                # k_vec must be updated first: it needs o_vec from before
                # this step's shrink.
                k_vec -= (torch.dot(k_vec, sample_grad) / denom) * o_vec
                o_vec -= (tmp / denom) * o_vec
            else:
                # First batch: only initialise the running vector.
                o_vec = sample_grad.clone()

        if idx > N:
            break

    return k_vec


def _batch_size(data):
    """Return the leading dimension of a batch (tokenizer-style or plain tensor)."""
    try:
        # Tokenizer-style batches carry the tensor under ``input_ids``.
        return data.input_ids.shape[0]
    except (AttributeError, KeyError):
        # Plain tensors (and mappings without ``input_ids``) fall back to
        # the batch object's own shape.
        return data.shape[0]


def _accumulate_weighted_grad(loader, model, loss_fn, device, accumulator):
    """Add each batch's gradient times its batch size into ``accumulator``; return the sample count."""
    seen = 0
    for data, label in tqdm(loader):
        model.zero_grad()
        real_num = _batch_size(data)
        data = data.to(device)
        label = label.to(device)
        output = model(data)

        loss = loss_fn(output.logits, label)
        # Scaling by the batch size weights each batch by its sample count.
        # NOTE(review): this yields a per-sample gradient sum only if
        # ``loss_fn`` averages over the batch — confirm.
        accumulator += get_sample_grad(model, loss) * real_num
        seen += real_num
    return seen


def train_unlearn(
    forget_loader,
    retain_loader,
    batch_1_retain_loader,
    model,
    loss_fn,
    device,
    alpha,
    mask=None,
    save_path: None | str = None,
):
    """Apply a one-shot, gradient-based perturbation to ``model`` and return it.

    Steps:
      1. Accumulate batch-size-weighted gradients over ``forget_loader`` and
         ``retain_loader`` (``Sf`` and ``Sr``, with sample counts ``Nf`` and
         ``Nr``).
      2. Form ``v = Sf / (Nf + Nr) - Sr * Nf / ((Nf + Nr) * Nr)``, which equals
         ``Nf / (Nf + Nr) * (Sf / Nf - Sr / Nr)``.
      3. Pass ``v`` through ``fisher`` using ``batch_1_retain_loader``.
      4. Add ``alpha * result`` to the trainable parameters through
         ``apply_perturb`` (optionally masked).
      5. Optionally save a checkpoint dictionary to ``save_path``.

    Args:
        forget_loader: Iterable of ``(data, label)`` batches. ``data`` must
            support ``.to(device)`` and expose its batch size either as
            ``data.input_ids.shape[0]`` or ``data.shape[0]``.
        retain_loader: Iterable of batches in the same format.
        batch_1_retain_loader: Loader handed to ``fisher``.
            NOTE(review): the name suggests batch size 1 — confirm.
        model: Module to modify; moved to ``device`` and set to eval mode. Its
            forward output must expose ``.logits``.
        loss_fn: Callable ``loss_fn(logits, label)`` returning the loss.
        device: Device used for the model, the batches and the accumulators.
        alpha: Scalar multiplier applied to the perturbation.
        mask: Optional per-parameter mask forwarded to ``apply_perturb``.
        save_path: When not ``None``, destination for ``torch.save`` of a dict
            holding ``model.state_dict()`` under ``"model_state_dict"``.

    Returns:
        The same ``model`` object, with its parameters perturbed in place.

    Raises:
        ZeroDivisionError: If ``retain_loader`` yields no samples.
    """
    model.to(device)

    # The concatenation is only a template for buffer size and dtype. Build
    # it once, without autograd tracking, and drop it straight away.
    with torch.no_grad():
        template = torch.cat(
            [param.view(-1) for param in get_require_grad_params(model, named=False)]
        )
    forget_grad = torch.zeros_like(template, device=device)
    retain_grad = torch.zeros_like(template, device=device)
    del template

    model.eval()

    total = _accumulate_weighted_grad(
        forget_loader, model, loss_fn, device, forget_grad
    )
    total_2 = _accumulate_weighted_grad(
        retain_loader, model, loss_fn, device, retain_grad
    )

    # After these two lines: forget_grad - retain_grad ==
    # total / (total + total_2) * (mean forget grad - mean retain grad).
    retain_grad *= total / ((total + total_2) * total_2)
    forget_grad /= total + total_2

    perturb = fisher(
        model,
        batch_1_retain_loader,
        device=device,
        criterion=loss_fn,
        v=forget_grad - retain_grad,
    )

    apply_perturb(model, alpha * perturb, mask=mask)

    if save_path is not None:
        torch.save(
            {
                "best_epoch": None,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": None,
                "loss": None,
            },
            save_path,
        )

    return model
