import torch

class Uniform(torch.autograd.Function):
    """Cost of the uniform transport plan restricted to the valid (unpadded) nodes.

    For batch element ``b`` with ``n = nnodes[b]``, the plan ``T[b]`` holds
    ``1 / n`` at every position ``(i, j)`` with ``i < n`` and ``j < n``, and
    ``0`` elsewhere (the padded region). The forward pass returns
    ``sum(cost_mat[b] * T[b])`` per batch element.

    ``T`` does not depend on ``cost_mat``, so the gradient w.r.t. ``cost_mat``
    is ``T`` scaled by the incoming gradient. ``nnodes`` receives no gradient.
    """

    @staticmethod
    def forward(
            ctx, cost_mat: torch.Tensor,
            nnodes: torch.Tensor) -> torch.Tensor:
        """Compute the masked uniform-plan cost.

        Args:
            ctx: autograd context; ``T`` is saved on it for ``backward``.
            cost_mat: batch of cost matrices, indexed as ``(batch, row, col)``.
            nnodes: per-batch number of valid nodes, indexed as ``(batch,)``.

        Returns:
            Tensor of shape ``(batch,)`` with the cost of each plan.
        """
        _, n_rows, n_cols = cost_mat.shape
        device = cost_mat.device

        # Padding mask via broadcasting: position (i, j) is padding when
        # either index reaches the number of valid nodes of its batch element.
        limit = nnodes.reshape(-1, 1, 1)
        rows = torch.arange(n_rows, device=device).view(1, -1, 1)
        cols = torch.arange(n_cols, device=device).view(1, 1, -1)
        mask_outer = (rows >= limit) | (cols >= limit)

        # Build the plan in the cost dtype so float64 costs are not truncated
        # through a float32 reciprocal. For nnodes == 0 the reciprocal is inf,
        # but every entry is masked, so masked_fill replaces it with 0.
        fill_value = 1.0 / nnodes.to(cost_mat.dtype)
        T = (fill_value.view(-1, 1, 1)
             .expand(-1, n_rows, n_cols)
             .masked_fill(mask_outer, 0.0))

        ctx.save_for_backward(T)

        return torch.sum(cost_mat * T, dim=[1, 2])

    @staticmethod
    def backward(ctx, grad_output):
        """Return one gradient per ``forward`` input: ``(d cost_mat, d nnodes)``."""
        T, = ctx.saved_tensors
        grad_cost = None
        if ctx.needs_input_grad[0]:
            # d/d cost_mat of sum(cost_mat * T) is T; scale by the upstream grad.
            grad_cost = T * grad_output[:, None, None]
        # nnodes is a count, not a differentiable input.
        return grad_cost, None

