import torch
from torch.autograd import grad 
from torch.nn import functional as F

# Helper functions: gradients of f and g w.r.t. params (y) and hparams (x)
def gradient_fy(args, labels, params, data, output):
    """Return the gradient of the mean cross-entropy loss w.r.t. ``params``.

    ``output`` must have been computed from ``params`` so that autograd can
    trace the dependency. ``args`` and ``data`` are not read; they are only
    accepted so every gradient helper in this module shares a call shape.
    No ``retain_graph``/``create_graph`` is requested, so the returned
    gradient is not itself differentiable.

    Returns:
        Gradient tensor shaped like ``params`` (the first entry if
        ``params`` is a sequence of tensors).
    """
    objective = F.cross_entropy(output, labels)
    d_params, *_ = torch.autograd.grad(objective, params)
    return d_params

def gradient_gy(args, labels_cp, params, data, hparams, output, reg_f, *,
                reduction='none'):
    """Return the gradient of the regularized inner loss w.r.t. ``params`` (y).

    The cross-entropy between ``output`` and ``labels_cp`` is passed to
    ``reg_f(params, hparams, loss)`` and the result is differentiated with
    respect to ``params``. ``create_graph=True`` keeps the returned gradient
    differentiable, so callers can take further (Hessian-/Jacobian-vector)
    derivatives of it.

    Args:
        args: not read; kept for a uniform helper signature.
        labels_cp: target labels for ``output``.
        params: tensor (or sequence of tensors) to differentiate against.
        data: not read; kept for a uniform helper signature.
        hparams: forwarded to ``reg_f``.
        output: logits computed from ``params``.
        reg_f: callable ``reg_f(params, hparams, loss)`` that must return a
            single-element tensor (required by ``torch.autograd.grad``).
        reduction: cross-entropy reduction. ``'none'`` (default) hands the
            per-sample losses to ``reg_f`` (MNIST data hyper-cleaning);
            ``'mean'`` hands it the averaged loss (NewsGroup l2-reg).

    Returns:
        Differentiable gradient shaped like ``params`` (the first entry if
        ``params`` is a sequence).
    """
    loss = F.cross_entropy(output, labels_cp, reduction=reduction)
    loss_regu = reg_f(params, hparams, loss)
    return torch.autograd.grad(loss_regu, params, create_graph=True)[0]

def gradient_gx(args, labels_cp, params, data, hparams, output, reg_f, *,
                reduction='none'):
    """Return the gradient of the regularized inner loss w.r.t. ``hparams`` (x).

    Same loss construction as ``gradient_gy``: the cross-entropy between
    ``output`` and ``labels_cp`` is passed to ``reg_f(params, hparams, loss)``.
    The result is differentiated w.r.t. ``hparams`` with
    ``create_graph=True`` so the gradient stays differentiable.

    Args:
        args: not read; kept for a uniform helper signature.
        labels_cp: target labels for ``output``.
        params: forwarded to ``reg_f``.
        data: not read; kept for a uniform helper signature.
        hparams: tensor (or sequence of tensors) to differentiate against.
        output: logits computed from ``params``.
        reg_f: callable ``reg_f(params, hparams, loss)`` that must return a
            single-element tensor.
        reduction: cross-entropy reduction. ``'none'`` (default) hands the
            per-sample losses to ``reg_f`` (MNIST data hyper-cleaning);
            ``'mean'`` hands it the averaged loss (NewsGroup l2-reg).

    Returns:
        Tuple ``(grad, loss_regu)``: the differentiable gradient w.r.t.
        ``hparams`` (first entry if a sequence) and the regularized loss.
    """
    loss = F.cross_entropy(output, labels_cp, reduction=reduction)
    loss_regu = reg_f(params, hparams, loss)
    grad = torch.autograd.grad(loss_regu, hparams, create_graph=True)[0]
    return grad, loss_regu

def tribo(params, hparams, val_data_list, args, out_f, reg_f, v):
    """Return ``-(d/d hparams) <dG/d params, v>`` on the first batch.

    The differentiable gradient of the regularized inner loss w.r.t.
    ``params`` (from ``gradient_gy``) is flattened and contracted with the
    detached vector ``v``. That scalar is then differentiated w.r.t.
    ``hparams``, and the result is negated.

    Args:
        params: model parameters (y).
        hparams: hyperparameters (x).
        val_data_list: pair ``(data_list, labels_list)``; only index 0 is used.
        args: forwarded to ``gradient_gy``.
        out_f: callable ``out_f(data, params)`` producing logits.
        reg_f: regularizer forwarded to ``gradient_gy``.
        v: vector contracted with the flattened y-gradient; presumably shaped
            ``(numel(params), 1)`` — confirm against callers.

    Returns:
        Tensor shaped like ``hparams`` (first entry if a sequence).
    """
    batches, targets = val_data_list
    batch, target = batches[0], targets[0]

    logits = out_f(batch, params)
    gy_flat = torch.reshape(
        gradient_gy(args, target, params, batch, hparams, logits, reg_f), [-1]
    )
    # v is detached so only hparams-dependence through gy_flat is kept.
    contraction = torch.matmul(gy_flat, v.detach())
    gyx_v = torch.autograd.grad(contraction, hparams, retain_graph=True)[0]
    return -gyx_v

def tribo_R(params, hparams, val_data_list, args, out_f, reg_f, v):
    """Return ``0.5 * Hv - Fy`` as a detached column vector.

    * ``Fy``: gradient of the mean cross-entropy w.r.t. ``params`` on batch 0
      (``gradient_fy``).
    * ``Hv``: derivative w.r.t. ``params`` of ``<dG/d params, v>`` on batch 1,
      i.e. a Hessian-vector product of the regularized inner loss.

    NOTE(review): if this is meant to be the gradient of
    ``0.5 * v^T H v - Fy^T v``, that gradient is ``H v - Fy`` (no 0.5).
    Confirm that the 0.5 scaling is intended.

    Returns:
        Detached tensor of shape ``(numel(params), 1)``.
    """
    batches, targets = val_data_list

    # Outer-loss gradient on the first batch.
    fy = gradient_fy(args, targets[0], params, batches[0],
                     out_f(batches[0], params))

    # Hessian-vector product on the second batch; v is detached so only the
    # params-dependence of the inner gradient is differentiated.
    gy_flat = torch.reshape(
        gradient_gy(args, targets[1], params, batches[1], hparams,
                    out_f(batches[1], params), reg_f),
        [-1],
    )
    hvp = torch.autograd.grad(torch.matmul(gy_flat, v.detach()), params,
                              retain_graph=True)[0]

    direction = hvp * 0.5 - fy
    return direction.reshape(-1).unsqueeze(1).detach()

def tribo_new(params, hparams, val_data_list, args, out_f, reg_f, v, ls_lr, eta_R, htR_old, grad_v_old):
    """Update the auxiliary vector ``v``, then return ``-Gyx @ v_new``.

    Steps:
      1. Compute ``grad_v`` (see NOTE below) and a momentum-style estimate
         ``htR = grad_v + (1 - eta_R) * (htR_old - grad_v_old)``.
      2. Take ``v_new = v - ls_lr * grad_v``.
      3. On batch 2, differentiate ``<dG/d params, v_new>`` w.r.t. ``hparams``
         and negate the result.

    Returns:
        Tuple ``(outer_update, v_new, htR, grad_v)``; ``grad_v`` is a detached
        column of shape ``(numel(params), 1)``.

    NOTE(review): ``Gy_gradient`` computed on batch 1 is never used, and
    ``G_gradient`` is only the flattened ``params``. As a result ``grad_v``
    is just ``v`` (detached, reshaped to a column), not a Hessian-based
    quantity. Compare ``stocbio`` (which uses ``params - eta * Gy``) and
    ``tribo_R`` (a Hessian-vector product minus Fy); one of those forms was
    probably intended. Confirm before relying on this function.
    NOTE(review): ``htR`` is computed and returned, but ``v_new`` is stepped
    with ``grad_v`` rather than ``htR``. Confirm which is intended.
    """
    data_list, labels_list = val_data_list
    # Inner gradient on batch 1 (currently unused; see NOTE above).
    output = out_f(data_list[1], params)
    Gy_gradient = gradient_gy(args, labels_list[1], params, data_list[1], hparams, output, reg_f) 

    # d/d params of <flat(params), v> equals v, so grad_v == v here.
    G_gradient = torch.reshape(params, [-1])
    Jacobian = torch.matmul(G_gradient, v)
    grad_v = torch.autograd.grad(Jacobian, params, retain_graph=True)[0]
    grad_v = torch.unsqueeze(torch.reshape(grad_v, [-1]), 1).detach()
    
    # Momentum-style correction using the previous step's estimate and grad.
    htR = grad_v + (1 - eta_R) * (htR_old - grad_v_old)
    v_new = v - ls_lr * grad_v

    # Gyx_gradient: mixed derivative contracted with the detached v_new, on batch 2.
    output = out_f(data_list[2], params)
    Gy_gradient = gradient_gy(args, labels_list[2], params, data_list[2], hparams, output, reg_f)
    Gy_gradient = torch.reshape(Gy_gradient, [-1])
    Gyx_gradient = torch.autograd.grad(torch.matmul(Gy_gradient, v_new.detach()), hparams, retain_graph=True)[0]
    outer_update = -Gyx_gradient 

    return outer_update, v_new, htR, grad_v

def stocbio(params, hparams, val_data_list, args, out_f, reg_f):
    """Return the hypergradient correction term ``-(d/d hparams) <dG/dy, v_Q>``.

    Uses three batches from ``val_data_list = (data_list, labels_list)``:
      * batch 0: ``v_0`` = gradient of the mean cross-entropy w.r.t.
        ``params`` (``gradient_fy``), as a detached column;
      * batch 1: ``args.hessian_q`` repeated products
        ``v <- v - args.eta * H v`` (H = Hessian of the regularized inner loss
        w.r.t. ``params``), each iterate stored in ``z_list``;
      * batch 2: ``<dG/d params, v_Q>`` differentiated w.r.t. ``hparams``.

    Args:
        params: model parameters (y).
        hparams: hyperparameters (x).
        val_data_list: pair ``(data_list, labels_list)`` with >= 3 batches.
        args: must provide ``eta`` and ``hessian_q``; also forwarded to the
            gradient helpers. ``hessian_q == 0`` is allowed and gives
            ``v_Q = eta * v_0``.
        out_f: callable ``out_f(data, params)`` producing logits.
        reg_f: regularizer forwarded to ``gradient_gy``.

    Returns:
        Tensor shaped like ``hparams`` (first entry if a sequence).
    """
    data_list, labels_list = val_data_list

    # Fy_gradient: outer-loss gradient on batch 0, flattened to a column.
    output = out_f(data_list[0], params)
    Fy_gradient = gradient_fy(args, labels_list[0], params, data_list[0], output)
    v_0 = torch.unsqueeze(torch.reshape(Fy_gradient, [-1]), 1).detach()

    # Hessian iterates on batch 1. gradient_gy keeps its graph
    # (create_graph=True), so differentiating <G_gradient, v> w.r.t. params
    # yields v - eta * H v.
    z_list = []
    output = out_f(data_list[1], params)
    Gy_gradient = gradient_gy(args, labels_list[1], params, data_list[1], hparams, output, reg_f)
    G_gradient = torch.reshape(params, [-1]) - args.eta * torch.reshape(Gy_gradient, [-1])

    for _ in range(args.hessian_q):
        Jacobian = torch.matmul(G_gradient, v_0)
        v_new = torch.autograd.grad(Jacobian, params, retain_graph=True)[0]
        v_0 = torch.unsqueeze(torch.reshape(v_new, [-1]), 1).detach()
        z_list.append(v_0)

    # NOTE(review): v_0 is now the LAST iterate, and the z_list terms are not
    # scaled by eta. A truncated series eta * sum_{k=0..Q} (I - eta*H)^k v_init
    # would be eta * (v_init + sum(z_list)). The original combination is kept
    # unchanged for reproducibility; confirm whether it is intended.
    v_Q = args.eta * v_0
    if z_list:
        # torch.stack rejects an empty list, so skip it when hessian_q == 0.
        v_Q = v_Q + torch.sum(torch.stack(z_list), dim=0)

    # Gyx_gradient: mixed derivative contracted with the detached v_Q, on batch 2.
    output = out_f(data_list[2], params)
    Gy_gradient = gradient_gy(args, labels_list[2], params, data_list[2], hparams, output, reg_f)
    Gy_gradient = torch.reshape(Gy_gradient, [-1])
    Gyx_gradient = torch.autograd.grad(torch.matmul(Gy_gradient, v_Q.detach()), hparams, retain_graph=True)[0]
    return -Gyx_gradient



