# Adapted from https://github.com/crispitagorico/torchspde
import torch
import numpy as np
from .neural_spde import NeuralSPDE
from torch.autograd import grad

def grad_var(u_i, grid_var):
    """ Differentiate ``u_i`` with respect to ``grid_var``, keeping the graph
        so the result can itself be differentiated again.

        Input:
              - u_i (batch, dim_x, dim_y, dim_t)
              - grid_var (batch, dim_x, dim_y, dim_t)
        Returns:
              - du_i/dvar(grid_var)  (batch, dim_x, dim_y, dim_t)

        Notes:
              - This is the gradient of ``u_i.sum()``. It equals the pointwise
                derivative only when each entry of ``u_i`` depends on
                ``grid_var`` solely at the same index.
              - If ``u_i`` does not depend on ``grid_var`` (e.g. the second
                derivative of an output that is affine in that coordinate),
                a zero tensor shaped like ``grid_var`` is returned instead of
                raising.
    """
    total = u_i.sum()
    if not total.requires_grad:
        # No autograd path from u_i to any leaf: the derivative is identically zero.
        return torch.zeros_like(grid_var)
    # allow_unused=True: u_i may have a graph that simply does not involve grid_var.
    (du,) = grad(total, grid_var, create_graph=True, allow_unused=True)
    if du is None:
        return torch.zeros_like(grid_var)
    return du


def grad_space(u_i, gridx, gridy):
    """ Spatial gradient of u_i, with the components stacked on a new last axis.

        Input:
              - u_i (batch, dim_x, dim_y, dim_t)
              - gridx, gridy (batch, dim_x, dim_y, dim_t)
        Returns:
              - (du_i/dx, du_i/dy)  (batch, dim_x, dim_y, dim_t, 2)
    """
    # x-component first, then y-component.
    components = [grad_var(u_i, coord) for coord in (gridx, gridy)]
    return torch.stack(components, dim=-1)


def grad_space_perp(u_i, gridx, gridy):
    """ Perpendicular gradient of u_i, i.e. the spatial gradient rotated by
        +90 degrees, with the components stacked on a new last axis.

        Input:
              - u_i (batch, dim_x, dim_y, dim_t)
              - gridx, gridy (batch, dim_x, dim_y, dim_t)
        Returns:
              - (-du_i/dy, du_i/dx)  (batch, dim_x, dim_y, dim_t, 2)
    """
    du_dy = grad_var(u_i, gridy)
    du_dx = grad_var(u_i, gridx)
    rotated = (-du_dy, du_dx)
    return torch.stack(rotated, dim=-1)


def laplacian(u_i, gridx, gridy):
    """ Spatial Laplacian of u_i, obtained by differentiating twice along
        each coordinate and summing the two terms.

        Input:
              - u_i (batch, dim_x, dim_y, dim_t)
              - gridx, gridy (batch, dim_x, dim_y, dim_t)
        Returns:
              - d^2u_i/dx^2 + d^2u_i/dy^2  (batch, dim_x, dim_y, dim_t)
    """
    def second_partial(coord):
        # Apply the first derivative, then differentiate it again along the same coordinate.
        return grad_var(grad_var(u_i, coord), coord)

    d2_dx2 = second_partial(gridx)
    d2_dy2 = second_partial(gridy)
    return d2_dx2 + d2_dy2

