import torch
from torch.autograd import grad
from torch.nn import functional as F
from torch.autograd.functional import vhp


def gradient_fy(args, labels, params, data, output):
    """Return d(mean cross-entropy)/d(params[0]) for the given logits.

    Args:
        args: Unused; kept so all gradient helpers share a call shape.
        labels: Target class indices passed to ``F.cross_entropy``.
        params: Sequence of tensors the loss is differentiated against.
            Only the gradient w.r.t. ``params[0]`` is returned.
        data: Unused here; ``output`` is assumed to already be the model's
            logits computed from it.
        output: Logits whose autograd graph reaches ``params``.

    Returns:
        Tensor of the same shape as ``params[0]``.

    Note:
        The graph is not retained, so ``output`` cannot be differentiated
        through again after this call.
    """
    ce = F.cross_entropy(output, labels)  # default reduction: batch mean
    grads = torch.autograd.grad(ce, params, retain_graph=False)
    first_grad = grads[0]
    return first_grad

def gradient_gy(args, labels_cp, params, data, hparams, output, reg_f):
    """Return the gradient of the regularized loss w.r.t. ``params[0]``.

    The per-sample cross-entropy (``reduction='none'``) is handed to
    ``reg_f(params, hparams, per_sample_loss)``. ``reg_f`` is expected to
    return a scalar, since it is differentiated directly. This per-sample
    form is the one used for the MNIST data-hyper-cleaning and MNIST l2reg
    experiments. The NewsGroup l2reg experiments instead used the default
    mean-reduced cross-entropy.

    Args:
        args: Unused; kept for a uniform helper signature.
        labels_cp: Target class indices.
        params: Sequence of tensors to differentiate against; only the
            gradient for ``params[0]`` is returned.
        data: Unused; ``output`` is assumed to be the logits for it.
        hparams: Hyperparameters forwarded unchanged to ``reg_f``.
        output: Logits whose autograd graph reaches ``params``.
        reg_f: Callable combining params, hparams and the per-sample loss
            into the objective to differentiate.

    Returns:
        Tensor shaped like ``params[0]``.
    """
    per_sample_loss = F.cross_entropy(output, labels_cp, reduction='none')
    objective = reg_f(params, hparams, per_sample_loss)
    return torch.autograd.grad(objective, params, retain_graph=False)[0]

def gradient_gx(args, labels_cp, params, data, hparams, output, reg_f):
    """Return ``exp(hparams[0]) * params[0] ** 2`` in closed form.

    No autograd call is made. The value is computed directly from
    ``hparams[0]`` and ``params[0]``.

    NOTE(review): this presumably is the analytic partial derivative of the
    regularizer w.r.t. the hyperparameters for a specific ``reg_f`` (e.g. an
    exp-weighted L2 penalty). The scale factor and the ``reg_f`` it matches
    are not visible here, so confirm against the regularizer in use.

    Args:
        args, labels_cp, data, output, reg_f: Unused; accepted so the
            signature mirrors ``gradient_gy``.
        params: Sequence whose first element is a tensor.
        hparams: Sequence whose first element is a tensor.

    Returns:
        Tensor resulting from broadcasting ``exp(hparams[0])`` against
        ``params[0] ** 2``.
    """
    weight = hparams[0].exp()
    squared = params[0].pow(2)
    return weight * squared

