# -*- coding: utf-8 -*-
"""
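 Implements a 1d monotone, piecewise-linear spline s: [0,1] -> [0,1]. The parameter
 vector theta is split into two halves of equal length, theta = [alpha, beta], such that
 pos = [0, cumsum(softmax(alpha))] and
 vals = [0, cumsum(softmax(beta))]
 determine the kinks of the spline, i.e.,
 s(pos[i]) = vals[i],
 and s is linear in between the kinks.
 Because softmax produces positive increments that sum to 1, both pos and vals increase
 from 0 to 1 (up to floating-point rounding), which makes the mapping invertible.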
"""
import torch
import torch.nn.functional as F
import functorch

def __eval(x,pos,vals):
    """
    Evaluate the piecewise-linear interpolant through the knots (pos[i], vals[i]).

    :param x: 0-d or 1-d tensor of query points; values outside [pos[0], pos[-1]] are clipped.
    :param pos: 1-d tensor of non-decreasing knot positions with at least two entries.
    :param vals: 1-d tensor of knot values, same length as pos.
    :return: (values, slopes), both 1-d with one entry per query point. slopes[i] is the
             slope of the segment used for x[i]: the right-hand segment at an interior knot,
             and the last segment at x == pos[-1].
    """
    x = x.clip(pos[0], pos[-1])  ## if x is out of range, just clip. This avoids rounding errors.
    if x.dim()==0:
        x=x.view(1)

    assert x.dim()==1            ## throw an error if the input is not 1-dimensional

    # For non-decreasing pos, the number of knots <= x is the index of the first knot
    # strictly greater than x, i.e. the right end of the segment containing x.
    # Unlike argmax over (pos > x), this does not collapse to 0 (and wrap to -1 below)
    # when x == pos[-1]; clamping maps that point onto the last segment instead.
    upper = (pos[:, None] <= x[None, :]).sum(dim=0).clamp(1, pos.shape[0] - 1)
    lower = upper - 1

    slopes = (vals[upper] - vals[lower]) / (pos[upper] - pos[lower])
    return vals[lower] + slopes * (x - pos[lower]), slopes


def get_values_and_positions(theta: torch.Tensor):
    """
    Turn the spline parameters into knot coordinates.

    Note that, despite the name, the result is returned in the order (pos, vals).

    :param theta: 1-d tensor of even, non-zero length n. The first n/2 entries (alpha)
                  parameterize the knot positions; the last n/2 entries (beta)
                  parameterize the knot values.
    :return: (pos, vals), each a 1-d tensor of length n/2 + 1 with theta's dtype and device.
             Each starts at 0 and accumulates softmax weights, so it ends at 1 up to rounding.
    """
    assert theta.dim()==1
    n = theta.shape[0]
    # An empty theta would yield a spline with no segments, which cannot be evaluated.
    assert n > 0 and n % 2 == 0
    half = n // 2

    def _knots(logits: torch.Tensor) -> torch.Tensor:
        # Prefix a zero so the knots start at 0. Building it with the logits' dtype and
        # device keeps the concatenation free of implicit dtype promotion.
        start = torch.zeros(1, dtype=logits.dtype, device=logits.device)
        return torch.cat((start, torch.cumsum(F.softmax(logits, dim=0), dim=0)))

    return _knots(theta[:half]), _knots(theta[half:])

    
def evaluate(x: torch.Tensor, theta: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Evaluate the spline parameterized by theta at the query points x.

    x is expected to be a 0-d or 1-d tensor of positions in [0,1]; points outside the
    knot range are clipped to it. Returns a pair (values, slopes), where slopes holds
    the derivative of the spline with respect to x at each query point.
    """
    knot_pos, knot_vals = get_values_and_positions(theta)
    return __eval(x, pos=knot_pos, vals=knot_vals)

def evaluateInverse(x: torch.Tensor, theta: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Evaluate the inverse of the spline parameterized by theta at the query points x.

    x is expected to be a 0-d or 1-d tensor of positions in [0,1]; points outside the
    knot range are clipped to it. Returns a pair (values, slopes), where slopes holds
    the derivative of the inverse spline with respect to x at each query point.
    """
    knot_pos, knot_vals = get_values_and_positions(theta)
    # Interpolating from the knot values back to the knot positions inverts the monotone map.
    return __eval(x, pos=knot_vals, vals=knot_pos)

def dfdt(x:torch.Tensor,theta: torch.Tensor)->torch.Tensor:
    """
    Compute, for each query point, the derivative of the spline value with respect to theta.

    :param x: 1-d tensor of N query positions in [0,1]. Each x[i] is evaluated on its own.
              If x has more dimensions, each slice x[i] is evaluated and only the first
              entry of its result is differentiated.
    :param theta: 1-d spline parameter tensor of length P (see get_values_and_positions).
    :return: tensor of shape (N, P); row i is d s(x[i]) / d theta.
    """
    def value_at(params: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
        # evaluate() returns (values, slopes); grad needs a scalar output, so take the
        # first (for a scalar point, the only) entry of values.
        return evaluate(point, params)[0][0]

    # grad differentiates with respect to its first argument (theta). vmap maps over
    # dim 0 of x while sharing the same theta across all points.
    per_point_grad = functorch.vmap(functorch.grad(value_at), in_dims=(None, 0))
    return per_point_grad(theta, x)
  