import torch
from torch.autograd import Function
from metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad

# Inherit from Function
class NNDistanceFunction(Function):
    """Autograd wrapper around the ``NNDistance`` backend.

    ``forward`` delegates to ``NNDistance`` and ``backward`` to
    ``NNDistanceGrad``. Both are static methods; the function is meant to be
    invoked through ``NNDistanceFunction.apply`` rather than instantiated.
    """

    @staticmethod
    def forward(ctx, seta, setb):
        """Run the backend on two point sets and return both distance outputs.

        Args:
            ctx: autograd context used to carry state over to ``backward``.
            seta: first point set; documented by the original author as
                batch_size x #dataset_points x 3.
            setb: second point set; documented by the original author as
                batch_size x #query_points x 3.

        Returns:
            The pair ``(dist1, dist2)`` produced by ``NNDistance``. The two
            index outputs of the backend are not returned; they are kept on
            ``ctx`` for the backward pass.
            NOTE(review): presumably dist1 is per-point of seta and dist2
            per-point of setb -- confirm against the backend.
        """
        # Both inputs are handed to the gradient kernel again in backward.
        ctx.save_for_backward(seta, setb)

        # The backend yields (dist1, idx1, dist2, idx2); the index results
        # are stashed directly as ctx attributes while unpacking.
        dist1, ctx.idx1, dist2, ctx.idx2 = NNDistance(seta, setb)
        return dist1, dist2

    @staticmethod
    def backward(ctx, grad_dist1, grad_dist2):
        """Compute the gradients with respect to ``seta`` and ``setb``.

        Args:
            ctx: context populated by ``forward`` (saved inputs, ``idx1`` and
                ``idx2``).
            grad_dist1: incoming gradient for the first forward output.
            grad_dist2: incoming gradient for the second forward output.

        Returns:
            ``(grada, gradb)`` as produced by ``NNDistanceGrad``, one
            gradient per forward input, in input order.
        """
        points_a, points_b = ctx.saved_tensors

        # The incoming gradients are made contiguous before reaching the
        # backend -- presumably it reads raw dense memory; TODO confirm.
        grad_inputs = NNDistanceGrad(
            points_a,
            points_b,
            ctx.idx1,
            ctx.idx2,
            grad_dist1.contiguous(),
            grad_dist2.contiguous(),
        )
        grad_a, grad_b = grad_inputs
        return grad_a, grad_b

# Functional entry point: custom autograd Functions are invoked through
# ``.apply``, so ``nn_distance(seta, setb)`` runs NNDistanceFunction.forward
# and returns its (dist1, dist2) pair.
nn_distance = NNDistanceFunction.apply

