import torch
import numpy as np

from helper import cubic_func, utils

def pinv_newton(model,
                forget_set,
                retain_set,
                config,
                **kwargs,
                ):
    """Apply a single pseudo-inverse Newton update to ``model`` in place.

    The loss gradient and Hessian are computed over ``retain_set`` via the
    ``cubic_func`` helpers, flattened with respect to the trainable
    parameters, and combined as ``dw = pinv(H) @ (-g)``. The resulting update
    is then added to every parameter that has ``requires_grad`` set.

    Args:
        model: Model exposing ``parameters()``; its trainable parameters are
            modified in place.
        forget_set: Not used by this function.
        retain_set: Dataset handed to ``utils.get_dataloader``; the gradient
            and Hessian are evaluated on the resulting loader.
        config: Configuration object. The attributes read here are ``llama``,
            ``train_batch_size``, ``loss`` (name of a class in ``torch.nn``
            that is instantiated with no arguments) and ``device``.
        **kwargs: Accepted and ignored.

    Returns:
        None. The model is updated as a side effect.

    Raises:
        AssertionError: If ``config.llama`` is truthy.
    """
    # Raised explicitly instead of via `assert` so the guard is not stripped
    # when Python runs with -O; the exception type and message are unchanged.
    if config.llama:
        raise AssertionError("PINV Newton can't be run on Llama.")

    train_loader = utils.get_dataloader(retain_set,
                                        shuffle=True,
                                        batch_size=config.train_batch_size)

    tuple_params = tuple(param for param in model.parameters() if param.requires_grad)
    loss_fn = getattr(torch.nn, config.loss)()

    # Negated gradient, flattened and moved to NumPy for the linear algebra.
    grad = cubic_func.gradient(model,
                               loss_fn,
                               train_loader,
                               device=config.device)
    grad = -cubic_func.compose_param_vector(grad, tuple_params)
    grad = grad.detach().cpu().numpy()

    # The composed Hessian is materialised as a NumPy array on the CPU.
    # NOTE(review): presumably (n_params, n_params), so memory grows
    # quadratically with the number of trainable parameters - confirm.
    hess = cubic_func.hessian(model,
                              loss_fn,
                              train_loader,
                              device=config.device)
    hess = cubic_func.compose_param_matrix(hess, tuple_params)
    hess = hess.detach().cpu().numpy()

    # hermitian=True tells NumPy to treat the Hessian as symmetric.
    dw = np.linalg.pinv(hess, hermitian=True) @ grad
    tuple_param_update = cubic_func.decompose_param_vector(dw, tuple_params)

    # Updates are matched to trainable parameters by position, in the same
    # order in which `tuple_params` was built above.
    trainable = (p for p in model.parameters() if p.requires_grad)
    for idx, param in enumerate(trainable):
        param.data += tuple_param_update[idx]

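    # NOTE(review): disabled code, never executed. It looks like a per-batch
    # variant built on `cubic_func_transformers` and a `trainer` - confirm
    # whether it is still needed before relying on it.
    # else: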
    #     trainer.train_dataset = dr_set
    #     dr_loader = trainer.get_train_dataloader()
    #     loss_fn = torch.nn.CrossEntropyLoss()
    #     for batch in dr_loader:
    #         tuple_params = tuple(p for p in model.parameters() if p.requires_grad)
    #         g = cubic_func_transformers.gradient(cfg, model, trainer, loss_fn, batch)
    #         H = cubic_func_transformers.hessian(cfg, model, trainer, loss_fn, batch)
    #         g = -cubic_func_transformers.compose_param_vector(g, tuple_params).cpu().detach().numpy()
    #         H = cubic_func_transformers.compose_param_matrix(H, tuple_params).cpu().detach().numpy()
    #         dw = numpy.linalg.pinv(H, hermitian=True) @ g
    #         tuple_param_update = cubic_func_transformers.decompose_param_vector(dw, tuple_params)
    #         i = 0
    #         for p in model.parameters():
    #             if p.requires_grad:
    #                 p.data += tuple_param_update[i]
    #                 i += 1
