import torch

### All input and output vectors live in the ambient space, i.e., points on S^n (and vectors tangent to it) are (n+1)-dim. vectors, batched one vector per row.

def project_to_sphere(x):
    """Radially project every row of ``x`` onto the unit sphere.

    Each row is divided by its Euclidean norm. A row that is entirely zero
    has no direction and comes back as NaN (0 / 0).
    """
    row_norms = (x * x).sum(dim=1).sqrt()
    return x / row_norms.view(-1, 1)

def project_to_tangentSpace(x, e):
    """Orthogonally project each row of ``e`` onto the tangent space at ``x``.

    The component of ``e`` along the matching row of ``x`` is subtracted.
    This is only the tangent-space projection when the rows of ``x`` have
    unit norm, which is assumed and not checked.
    """
    along_normal = (x * e).sum(dim=1).view(-1, 1)
    return e - x * along_normal

def exponential_map(x, v, eps=1e-10):
    """Exponential map on the unit sphere, applied row-wise: return Exp_x(v).

    For each row, ``Exp_x(v) = x * cos(|v|) + v * sin(|v|) / |v|``.

    Args:
        x: points on the sphere, one per row; assumed to have unit norm.
        v: tangent vectors, one per row; assumed perpendicular to the
            matching row of ``x``.
        eps: rows with ``|v| < eps`` are treated as a zero step and map to
            ``x`` itself.

    Returns:
        Tensor of the moved points, one per row.
    """
    sq_norm = torch.sum(v * v, dim=1)
    # Boolean row selector; it carries no gradient.
    small = torch.sqrt(sq_norm.detach()) < eps
    # Replace tiny norms with 1 *before* the sqrt and the division. Otherwise
    # sin(0)/0 is NaN in the forward pass and d(sqrt)/ds at 0 is 0/0 in the
    # backward pass. Either would leak NaN into the gradient even though
    # those rows are discarded below.
    safe_norm = torch.sqrt(torch.where(small, torch.ones_like(sq_norm), sq_norm))
    moved = (x * torch.cos(safe_norm).view(-1, 1)
             + v * (torch.sin(safe_norm) / safe_norm).view(-1, 1))
    return torch.where(small.view(-1, 1), x, moved)

def distance(x, y):
    """Geodesic (great-circle) distance between matching rows of ``x`` and ``y``.

    Both inputs are assumed to have unit-norm rows. Rounding can push their
    row-wise dot products slightly outside [-1, 1], so the dot products are
    clamped into that range before ``acos``.

    Note: the derivative of ``acos`` is infinite at +/-1, so the gradient is
    not finite for identical or antipodal pairs. Clamping with a small margin
    (e.g. into [-1 + 1e-5, 1 - 1e-5]) would avoid that but bias small
    distances; no margin is applied here.
    """
    cosines = torch.clamp(torch.sum(x * y, dim=1), min=-1.0, max=1.0)
    return torch.acos(cosines)

def pairwise_distance(x):
    """All-pairs geodesic distances between the rows of ``x``.

    Args:
        x: 2-D tensor whose rows are assumed to be unit-norm points.

    Returns:
        Square tensor ``D`` with ``D[i, j] = acos(<x_i, x_j>)``. Cosines are
        clamped into [-1, 1] to absorb rounding error.
    """
    # One matrix product yields the Gram matrix directly. It avoids building
    # the (N, N, dim) broadcast product that the elementwise form allocates.
    cosines = torch.clamp(x @ x.t(), min=-1.0, max=1.0)
    return torch.acos(cosines)

def logarithm_map(x, y, eps=1e-10, returnDistAlso = False):
    """Logarithm map on the unit sphere, applied row-wise: return Log_x(y).

    For each row, the component of ``y`` orthogonal to ``x`` is rescaled to
    length ``acos(<x, y>)``. When that orthogonal component has norm
    ``<= eps`` (``y`` equal or antipodal to ``x``), it is returned unscaled.

    Args:
        x: base points, one per row; assumed to have unit norm.
        y: target points, one per row; assumed to have unit norm.
        eps: threshold on the orthogonal component's norm below which no
            rescaling is done.
        returnDistAlso: if True, also return the geodesic distances
            ``acos(<x, y>)``.

    Returns:
        The tangent vectors, or ``(tangent_vectors, distances)`` when
        ``returnDistAlso`` is True.
    """
    cos_theta = torch.clamp(torch.sum(x * y, dim=1), min=-1.0, max=1.0)
    residual = y - x * cos_theta.view(-1, 1)
    sq_norm = torch.sum(residual * residual, dim=1)
    # Boolean row selector; it carries no gradient.
    regular = torch.sqrt(sq_norm.detach()) > eps
    # For rows that are discarded below, feed acos and sqrt harmless inputs.
    # Their infinite derivatives (acos at +/-1, sqrt at 0) would otherwise
    # turn the zero upstream gradient into NaN.
    safe_dist = torch.acos(torch.where(regular, cos_theta, torch.zeros_like(cos_theta)))
    safe_norm = torch.sqrt(torch.where(regular, sq_norm, torch.ones_like(sq_norm)))
    scaled = residual * (safe_dist / safe_norm).view(-1, 1)
    output = torch.where(regular.view(-1, 1), scaled, residual)
    if returnDistAlso:
        return output, torch.acos(cos_theta)
    return output
"""
def GramSchmidtBasis(x):
    if x.is_cuda:
        basis = torch.cuda.FloatTensor(x.shape[0], x.shape[1] - 1, x.shape[1]).zero_()
    else:
        basis = torch.FloatTensor(x.shape[0], x.shape[1] - 1, x.shape[1]).zero_()
    return basis
"""

class acosSquare(torch.autograd.Function):
    """Element-wise ``acos(x) ** 2`` with a stabilised backward pass.

    The exact derivative ``-2 * acos(x) / sin(acos(x))`` is 0/0 at ``x = 1``.
    ``forward`` clamps the input into ``[-1 + eps, 1]``. For inputs above
    ``1 - eps``, ``backward`` substitutes ``-2 / x``, which tends to the
    exact limit -2 as ``x -> 1``.
    """

    @staticmethod
    def forward(ctx, input, eps=1e-7):
        """Return ``acos(clamp(input, -1 + eps, 1)) ** 2`` element-wise.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            input: tensor of cosines. It is not modified.
            eps: margin kept away from -1, and the width of the
                neighbourhood of +1 where ``backward`` uses its substitute
                derivative.
        """
        ctx.eps = eps
        # Clamp out of place. Writing into ``input`` would silently change
        # the caller's tensor (and would require ctx.mark_dirty).
        clamped = torch.clamp(input, min=-1. + eps, max=1.)
        input_acos = clamped.acos()
        ctx.save_for_backward(clamped, input_acos)
        return input_acos * input_acos

    @staticmethod
    def backward(ctx, grad_output):
        """Chain ``grad_output`` through ``acos(x) ** 2``.

        Returns the gradient w.r.t. ``input`` and ``None`` for ``eps``, which
        is not differentiable. Autograd accepts a trailing ``None`` even when
        ``eps`` was not passed to ``apply``.
        """
        clamped, input_acos = ctx.saved_tensors
        doutput_dinput = -2. * input_acos / input_acos.sin()
        # Near +1 the expression above is 0/0; use the substitute derivative.
        near_one = clamped > 1. - ctx.eps
        doutput_dinput[near_one] = -2. / clamped[near_one]
        return grad_output * doutput_dinput, None

acos_square = acosSquare.apply