import numpy as np
import torch
from torch.autograd import grad


def laplace(y, x):
    """Return the Laplacian of ``y`` with respect to ``x``.

    Computed as the divergence of the gradient: ``gradient`` supplies the
    first derivatives and ``divergence`` sums the unmixed second derivatives.

    Args:
        y: Tensor produced from ``x`` through a differentiable graph.
        x: Tensor that ``y`` is differentiated against; its last axis
            enumerates the coordinates.

    Returns:
        Tensor with a trailing dimension of size 1 holding the sum, over
        every coordinate ``i`` of ``x``, of the second derivative along ``i``.
    """
    first_derivatives = gradient(y, x)
    # The gradient has exactly one component per coordinate of ``x``, so
    # component i must be paired with coordinate i. ``divergence`` defaults to
    # an offset of 1, which would index one past the last coordinate here and
    # raise IndexError.
    return divergence(first_derivatives, x, x_offset=0)


def divergence(y, x, x_offset=1):
    """Sum the derivative of each component of ``y`` along its paired coordinate of ``x``.

    Component ``i`` on the last axis of ``y`` is differentiated with respect
    to ``x``; only coordinate ``i + x_offset`` of that derivative is kept and
    added to the running total.

    Args:
        y: Tensor whose last axis enumerates components; must depend on ``x``.
        x: Tensor to differentiate against.
        x_offset: Shift between a component index in ``y`` and the coordinate
            index in ``x``. With the default of 1, coordinate 0 of ``x`` is
            never used (NOTE(review): presumably a time axis - confirm).

    Returns:
        The accumulated sum with a trailing dimension of size 1, or the float
        ``0.`` when ``y`` has no components.
    """
    total = 0.
    for component_idx in range(y.shape[-1]):
        component = y[..., component_idx]
        # create_graph=True keeps the result differentiable for higher orders.
        d_component = grad(component, x, torch.ones_like(component), create_graph=True)[0]
        # Integer index + unsqueeze keeps the trailing axis of size 1.
        total += d_component[..., component_idx + x_offset].unsqueeze(-1)
    return total


def gradient(y, x, grad_outputs=None):
    """Differentiate ``y`` with respect to ``x`` via ``torch.autograd.grad``.

    Args:
        y: Tensor computed from ``x``.
        x: Tensor to differentiate against.
        grad_outputs: Weights applied to ``y`` in the vector-Jacobian
            product; a tensor of ones shaped like ``y`` when omitted.

    Returns:
        The derivative tensor for ``x``. It is built with
        ``create_graph=True`` so it can be differentiated again.
    """
    seed = torch.ones_like(y) if grad_outputs is None else grad_outputs
    (dy_dx,) = torch.autograd.grad(y, [x], grad_outputs=seed, create_graph=True)
    return dy_dx


def custom_hessian(y, x):
    """Compute the Hessian of every channel of ``y`` with respect to ``x``.

    Args:
        y: Tensor of shape (num_observations, channels) computed from ``x``.
        x: Tensor of shape (num_observations, dims).

    Returns:
        Tuple ``(hess, status)``. ``hess`` has shape
        (num_observations, channels, dims, dims) and shares ``y``'s dtype and
        device; entry ``[..., i, j, k]`` is the derivative with respect to
        ``x[..., k]`` of the derivative of ``y[..., i]`` with respect to
        ``x[..., j]``. ``status`` is ``-1`` if ``hess`` contains NaN, else ``0``.

    NOTE(review): every autograd call is seeded with ones over all
    observations, so the entries are per-observation derivatives only if each
    row of ``y`` depends solely on the matching row of ``x`` - confirm with
    callers.
    """
    num_inputs = x.shape[-1]
    # Allocate with y's dtype and device: a plain torch.zeros buffer is
    # float32 and would silently downcast float64 second derivatives.
    hess = y.new_zeros((*y.shape, num_inputs, num_inputs))
    grad_y = torch.ones_like(y[..., 0])

    for output_i in range(y.shape[-1]):
        # First derivatives of this channel, kept differentiable.
        dydx = grad(y[..., output_i], x, grad_y, create_graph=True)[0]

        # Differentiating each first-derivative component fills one Hessian row.
        for dim_j in range(num_inputs):
            hess[..., output_i, dim_j, :] = grad(dydx[..., dim_j], x, grad_y, create_graph=True)[0]

    status = -1 if torch.isnan(hess).any() else 0
    return hess, status


def hessian(y, x):
    """Compute the Hessian of every channel of ``y`` with respect to ``x``.

    Args:
        y: Tensor of shape (meta_batch_size, num_observations, channels)
            computed from ``x``.
        x: Tensor of shape (meta_batch_size, num_observations, dims).

    Returns:
        Tuple ``(h, status)``. ``h`` has shape
        (meta_batch_size, num_observations, channels, dims, dims) and shares
        ``y``'s dtype and device; entry ``[..., i, j, k]`` is the derivative
        with respect to ``x[..., k]`` of the derivative of ``y[..., i]`` with
        respect to ``x[..., j]``. ``status`` is ``-1`` if ``h`` contains NaN,
        else ``0``.

    NOTE(review): every autograd call is seeded with ones over all
    observations, so the entries are per-observation derivatives only if each
    observation of ``y`` depends solely on the matching observation of ``x`` -
    confirm with callers.
    """
    num_inputs = x.shape[-1]
    # Allocate with y's dtype and device: a plain torch.zeros buffer is
    # float32 and would silently downcast float64 second derivatives.
    h = y.new_zeros((*y.shape, num_inputs, num_inputs))
    grad_y = torch.ones_like(y[..., 0])

    for i in range(y.shape[-1]):
        # First derivatives of this channel, kept differentiable.
        dydx = grad(y[..., i], x, grad_y, create_graph=True)[0]

        # Differentiating each first-derivative component fills one Hessian row.
        for j in range(num_inputs):
            h[..., i, j, :] = grad(dydx[..., j], x, grad_y, create_graph=True)[0]

    status = -1 if torch.isnan(h).any() else 0
    return h, status


def gradients(outputs, inputs):
    """Differentiate ``outputs`` with respect to ``inputs``, seeded with ones.

    Args:
        outputs: Tensor computed from ``inputs``.
        inputs: Tensor (or sequence of tensors) to differentiate against.

    Returns:
        The tuple returned by ``torch.autograd.grad``, built with
        ``create_graph=True`` so the derivatives remain differentiable.
    """
    seed = torch.ones_like(outputs)
    return torch.autograd.grad(
        outputs,
        inputs,
        grad_outputs=seed,
        create_graph=True,
    )


def rmse(x1, x2):
    """Return the root-mean-square error between two arrays.

    Args:
        x1: Array-like of real numbers.
        x2: Array-like of real numbers, broadcastable against ``x1``.

    Returns:
        The square root of the mean squared element-wise difference, as a
        float64 scalar.
    """
    # Compute in float64: with integer inputs (e.g. uint8 images) both the
    # subtraction and the squaring wrap around and give a wrong error.
    diff = np.asarray(x1, dtype=np.float64) - np.asarray(x2, dtype=np.float64)
    return np.mean(diff ** 2) ** 0.5
