import torch
import math


def calc_chamfer(dist_matrix, nnodes, nnodes_col=None):
    """Compute a batched, padding-aware Chamfer distance from pairwise distances.

    For batch element ``b`` only the top-left ``nnodes[b] x nnodes_col[b]``
    sub-matrix of ``dist_matrix[b]`` is used; everything outside it is treated
    as padding. The result is the average of two sums:

    * the sum over columns of each column's minimum over the valid rows, and
    * the sum over rows of each row's minimum over the valid columns.

    A row/column whose minimum is not finite (padding, or all-inf distances)
    contributes 0.

    Args:
        dist_matrix: Tensor of shape ``(batch, n_rows, n_cols)``.
        nnodes: Number of valid rows per batch element, shape ``(batch,)``.
            May be a tensor or a sequence; it is moved to ``dist_matrix``'s
            device.
        nnodes_col: Number of valid columns per batch element. Defaults to
            ``nnodes`` (the original, symmetric behaviour).

    Returns:
        Tensor of shape ``(batch,)`` with the Chamfer distance per element.
    """
    batch, n_rows, n_cols = dist_matrix.shape
    device = dist_matrix.device
    # Masks must live on the same device as dist_matrix for masked_fill.
    nnodes = torch.as_tensor(nnodes, device=device)
    if nnodes_col is None:
        nnodes_col = nnodes
    else:
        nnodes_col = torch.as_tensor(nnodes_col, device=device)

    # torch.min cannot reduce over an empty dimension. With no rows or no
    # columns there are no valid pairs, so every term is zero.
    if n_rows == 0 or n_cols == 0:
        return dist_matrix.new_zeros(batch)

    # Broadcastable padding masks: (batch, n_rows, 1) and (batch, 1, n_cols).
    row_pad = torch.arange(n_rows, device=device)[None, :, None] >= nnodes[:, None, None]
    col_pad = torch.arange(n_cols, device=device)[None, None, :] >= nnodes_col[:, None, None]
    dist_inf = dist_matrix.masked_fill(row_pad | col_pad, math.inf)

    zero = dist_matrix.new_zeros(())

    # Per column: nearest valid row. Non-finite minima (padding) count as 0.
    min_in, _ = torch.min(dist_inf, dim=1)
    sum_min_in = torch.where(min_in < math.inf, min_in, zero).sum(dim=1)

    # Per row: nearest valid column. Non-finite minima (padding) count as 0.
    min_out, _ = torch.min(dist_inf, dim=2)
    sum_min_out = torch.where(min_out < math.inf, min_out, zero).sum(dim=1)

    return 0.5 * (sum_min_in + sum_min_out)

# TODO: Try generating the matching matrix explicitly and implementing the
# backward pass the same way the other matching functions do.
