"""
This module provides a few differential operators for PDE problems, built on
tf.GradientTape. It makes no attempt at generality.

"""
import tensorflow as tf

def gradient(R, x):
    """
    Computes the gradient of R at x.

    Parameters
    ----------
    R : function
        R is a function that maps (K,d) tensors to (K,1) tensors.
    x : tensor or array-like
        Tensor of shape (K,d) where the derivative is evaluated.
        Inputs that are not ``tf.Variable`` are passed through
        ``tf.convert_to_tensor`` so the tape can watch them (e.g. NumPy
        arrays); tensors are returned unchanged by that call.

    Returns
    -------
    gradRx : tensor
        Tensor of shape (K,d), the derivative of R at x. If R does not
        depend on x at all, a zero tensor shaped like x is returned
        instead of ``None``.

    Notes
    -----
    ``GradientTape.gradient`` differentiates the sum of the entries of
    R(x), so row k of the result is the gradient of row k of R(x) only
    if R acts row-wise (row k of the output depends only on row k of x).

    """
    if not isinstance(x, tf.Variable):
        # GradientTape.watch only accepts tensors/variables, not NumPy arrays.
        x = tf.convert_to_tensor(x)
    with tf.GradientTape() as tape:
        tape.watch(x)
        Rx = R(x)
    # By default tape.gradient returns None when Rx is not connected to x
    # (e.g. a constant R); request zeros so callers always get a tensor.
    gradRx = tape.gradient(
        Rx, x, unconnected_gradients=tf.UnconnectedGradients.ZERO
    )
    return gradRx

def dot(v, w):
    """
    Row-wise inner product of batched vectors.

    Each row of ``v`` is multiplied elementwise with the matching row of
    ``w``, and the products are summed along axis 1.

    Parameters
    ----------
    v : Tensor
        A tensor of shape (K,d).
    w : Tensor
        A tensor of the same shape as v.

    Returns
    -------
    A tensor of shape (K, 1) whose k-th entry is the inner product of
    row k of v with row k of w.

    """
    products = v * w
    return tf.reduce_sum(products, axis=1, keepdims=True)

def grad_square(R, x):
    """
    Squared Euclidean norm of the gradient of R at x.

    Parameters
    ----------
    R : function
        R is a function that maps (K,d) tensors to (K,1) tensors.
    x : tensor
        Tensor of shape (K,d) where the derivative is evaluated.

    Returns
    -------
    sq_norm : tensor
        Tensor of shape (K, 1); row k holds the squared norm of row k of
        the gradient of R at x, as computed by ``gradient``.

    """
    grad_rows = gradient(R, x)
    sq_norm = dot(grad_rows, grad_rows)
    return sq_norm

def prime(f, X):
    """
    Derivative of f with respect to the last column of its input.

    Used in parametric problems that follow the convention of ordering
    the input columns as (p, x): the derivative is taken only with
    respect to the last column of X. Because only that one column is
    differentiated, this is useful only when the variable of interest
    is one-dimensional.

    Parameters
    ----------
    f : function
        A function with signature (n, dim) -> (n, 1) operating on
        tensors.
    X : tensor or array-like
        A valid input for f, i.e., a tensor of shape (n, dim), where
        n is any positive integer. Non-tensor inputs (e.g. NumPy arrays)
        are converted with ``tf.convert_to_tensor``.

    Returns
    -------
    The derivative of f at X as a tensor of the shape (n,1). If f does
    not depend on the last column at all, a zero tensor is returned
    instead of ``None``.

    Notes
    -----
    ``GradientTape.gradient`` differentiates the sum of the entries of
    f(X), so row k of the result is the derivative of row k of f only
    if f acts row-wise.

    """
    X = tf.convert_to_tensor(X)
    alpha = X[:, :-1]
    # Slicing with ``-1:`` keeps the column axis, so x is already (n, 1)
    # without needing n (len(X) fails when the batch size is unknown,
    # e.g. inside tf.function).
    x = X[:, -1:]
    with tf.GradientTape() as tape:
        tape.watch(x)
        # concat of (n, dim-1) and (n, 1) already has X's shape.
        fx = f(tf.concat([alpha, x], axis=1))
    # Zeros rather than None when f ignores its last input column.
    dfdx = tape.gradient(
        fx, x, unconnected_gradients=tf.UnconnectedGradients.ZERO
    )
    return dfdx
