import torch
from torch.autograd import grad
import torch.nn.functional as F


def hessian(y, x):
    ''' Hessian of y wrt x, computed one output channel at a time with autograd.

    y: shape (meta_batch_size, num_observations, channels)
    x: shape (meta_batch_size, num_observations, 2)

    Any leading shape works as long as y.shape[:-1] == x.shape[:-1]; the
    trailing dimension of x may be of any size.

    Returns a tuple (h, status):
      h: shape (*y.shape, x.shape[-1], x.shape[-1]), where h[..., i, j, :] is
         the gradient wrt x of (d y[..., i] / d x[..., j]). It has the dtype
         and device of y.
      status: -1 if h contains any NaN, otherwise 0.

    Both grad calls use create_graph=True, so h stays differentiable.
    '''
    num_channels = y.shape[-1]
    num_coords = x.shape[-1]

    # All-ones seed: each call back-propagates the sum over every leading
    # position, which gives per-position derivatives when each y row depends
    # only on its own x row (NOTE(review): assumed, confirm against callers).
    seed = torch.ones_like(y[..., 0])

    # Allocate from y so the buffer matches its dtype/device; a default-dtype
    # buffer would silently downcast float64 derivatives on assignment.
    h = y.new_zeros((*y.shape, num_coords, num_coords))

    for i in range(num_channels):
        # First derivative of channel i wrt every coordinate of x.
        dydx = grad(y[..., i], x, seed, create_graph=True)[0]

        # Differentiate each first-derivative component again to fill row j.
        for j in range(num_coords):
            h[..., i, j, :] = grad(dydx[..., j], x, seed, create_graph=True)[0]

    status = -1 if torch.isnan(h).any() else 0
    return h, status


def factorized_directional_vector_field_derivative(y, x, direction):
    ''' Sum over channels of |<d y[..., i] / d x, direction>|.

    For each channel i of y (indexed along the last dimension), the gradient
    wrt x is taken with an all-ones seed and contracted with `direction` over
    the last dimension; the absolute values are accumulated over channels.

    y: tensor whose last dimension indexes channels
    x: tensor y was computed from (must require grad)
    direction: tensor contracted with the gradient along the last dimension.
        It is used as given and is NOT normalized here; pass unit vectors if
        a unit-direction derivative is wanted.

    Returns the accumulated tensor (last dimension contracted away), or the
    int 0 if y has no channels. The result stays differentiable because the
    gradients are taken with create_graph=True.
    '''
    dir_derivative = 0
    for i in range(y.shape[-1]):  # for each channel
        channel = y[..., i:i+1]
        # Gradient of this channel w.r.t. the coordinates.
        pd = grad(channel, x, torch.ones_like(channel), create_graph=True)[0]
        # Inner product of the gradient with the direction.
        dir_derivative += torch.abs(torch.einsum('...j,...j', pd, direction))
    return dir_derivative


def directional_vector_field_derivative(y, x):
    ''' Accumulate |<d y[..., j, i] / d x, unit(y[..., j, :])>| over all j, i.

    y is indexed as y[..., j, i]; each vector y[..., j, :] is L2-normalized
    along its last dimension and serves as the direction for the gradients of
    its own components. Gradients are taken with an all-ones seed and
    create_graph=True, so the result stays differentiable.

    Returns the accumulated tensor, or the int 0 if y has no components.
    '''
    unit_vectors = F.normalize(y, dim=-1)
    num_vectors, num_components = y.shape[-2], y.shape[-1]

    total = 0
    for vec_idx in range(num_vectors):
        heading = unit_vectors[..., vec_idx, :]
        for comp_idx in range(num_components):
            component = y[..., vec_idx, comp_idx:comp_idx+1]
            component_grad = grad(component, x, torch.ones_like(component), create_graph=True)[0]
            projection = torch.einsum('...j,...j', component_grad, heading)
            total = total + torch.abs(projection)
    return total


def laplace(y, x):
    ''' Laplacian of y wrt x: the divergence of the gradient of y. '''
    return divergence(gradient(y, x), x)


def divergence(y, x):
    ''' Divergence of y wrt x: sum over k of d y[..., k] / d x[..., k].

    Each component is differentiated with an all-ones seed and
    create_graph=True; only the matching coordinate slice of the gradient is
    kept, so the result has a trailing dimension of size 1.

    Returns the accumulated tensor, or the float 0. if y has no components.
    '''
    total = 0.
    for k in range(y.shape[-1]):
        component = y[..., k]
        component_grad = grad(component, x, torch.ones_like(component), create_graph=True)[0]
        total = total + component_grad[..., k:k+1]
    return total


def gradient(y, x, grad_outputs=None, create_graph=True):
    ''' Gradient of y wrt x via torch.autograd.grad.

    grad_outputs: seed passed to autograd; defaults to ones shaped like y.
    create_graph: forwarded to autograd so the result can be differentiated
        again when True.
    '''
    seed = torch.ones_like(y) if grad_outputs is None else grad_outputs
    (dy_dx,) = torch.autograd.grad(y, [x], grad_outputs=seed, create_graph=create_graph)
    return dy_dx


def jacobian(y, x):
    ''' Jacobian of y wrt x, computed one output channel at a time with autograd.

    y: tensor whose last dimension indexes output channels
    x: tensor y was computed from (must require grad), with
       x.shape[:-1] == y.shape[:-1]

    Returns a tuple (jac, status):
      jac: shape (*y.shape, x.shape[-1]), where jac[..., i, :] is the gradient
           of y[..., i] wrt x. It has the dtype and device of y.
      status: -1 if jac contains any NaN, otherwise 0.

    The grad calls use create_graph=True, so jac stays differentiable.
    '''
    # The shape must be built from sequences; adding the bare int
    # x.shape[-1] to a list raises TypeError.
    jac = y.new_zeros((*y.shape, x.shape[-1]))

    for i in range(y.shape[-1]):
        # All-ones seed: back-propagates the sum of channel i over every
        # leading position, giving per-position gradients when each y row
        # depends only on its own x row (NOTE(review): assumed, confirm).
        y_i = y[..., i]
        jac[..., i, :] = grad(y_i, x, torch.ones_like(y_i), create_graph=True)[0]

    status = -1 if torch.isnan(jac).any() else 0
    return jac, status




