import numpy
import torch

from helper import cubic_func, utils

def damped_newton(model,
                  forget_set,
                  retain_set,
                  config,
                  **kwargs,
                  ):
    """Apply one damped Newton update to the trainable parameters of ``model``.

    The gradient and Hessian of the loss are computed over ``retain_set`` via
    the ``cubic_func`` helpers, and the step ``dw`` solving

        (H + gamma * I) dw = -g

    is added in place to every parameter with ``requires_grad=True``.

    Args:
        model: Torch module whose trainable parameters are updated in place.
        forget_set: Accepted for interface compatibility; not used here.
        retain_set: Dataset the gradient and Hessian are evaluated on.
        config: Settings object. The attributes read are ``llama``,
            ``train_batch_size``, ``loss`` (name of a class in ``torch.nn``),
            ``device`` and ``gamma`` (damping coefficient).
        **kwargs: Ignored.

    Returns:
        None. ``model`` is modified in place.

    Raises:
        AssertionError: If ``config.llama`` is truthy.
        numpy.linalg.LinAlgError: If the damped Hessian is singular.
    """
    assert not config.llama, "Damped Newton can't be run on Llama."

    train_loader = utils.get_dataloader(retain_set,
                                        shuffle=True,
                                        batch_size=config.train_batch_size)

    # Fixed ordering of trainable parameters; reused below so the update is
    # applied to exactly the tensors the flattened vector was built from.
    tuple_params = tuple(param for param in model.parameters() if param.requires_grad)
    loss_fn = getattr(torch.nn, config.loss)()

    grad = cubic_func.gradient(model,
                               loss_fn,
                               train_loader,
                               device=config.device)
    # Negated here so the solve below directly yields the step to add.
    neg_grad = -cubic_func.compose_param_vector(grad, tuple_params)
    neg_grad = neg_grad.detach().cpu().numpy()

    hess = cubic_func.hessian(model,
                              loss_fn,
                              train_loader,
                              device=config.device)
    hess = cubic_func.compose_param_matrix(hess, tuple_params)
    hess = hess.detach().cpu().numpy()

    # Solve the damped system instead of forming the explicit inverse: same
    # result mathematically, but cheaper and numerically better conditioned.
    damped_hess = hess + config.gamma * numpy.eye(*hess.shape)
    dw = numpy.linalg.solve(damped_hess, neg_grad)

    # NOTE(review): assumes decompose_param_vector returns one update per
    # entry of tuple_params, in the same order - confirm against cubic_func.
    tuple_param_update = cubic_func.decompose_param_vector(dw, tuple_params)
    for index, param in enumerate(tuple_params):
        param.data += tuple_param_update[index]
