# adapted from https://gist.github.com/WangZixuan/4c4cdf49ce9989175e94524afc946726
# and slightly altered

import torch


def chamfer_distance_with_batch(p1, p2, debug=False, device="cuda"):
    '''
    Calculate a symmetric squared Chamfer distance between two batches of
    point sets, summed over the batch.

    Points are stored channels-first: column ``i`` of ``p1[b]`` is the i-th
    D-dimensional point of batch item ``b``. N and M may differ.

    For each batch item the two directional terms

        d1 = sum_i min_j ||p1_i - p2_j||^2    (over the N points of p1)
        d2 = sum_j min_i ||p1_i - p2_j||^2    (over the M points of p2)

    are combined with max(d1, d2) rather than summed, and the per-item
    values are then summed over the batch.

    Note: an intermediate difference tensor of size [B, D, M, N] is built,
    so memory grows with B * D * M * N.

    :param p1: tensor of size [B, D, N]
    :param p2: tensor of size [B, D, M]
    :param debug: whether to print intermediate tensors
    :param device: device both inputs are moved to before computing.
        Defaults to "cuda" (the previously hard-coded behaviour); pass None
        to compute on p1's current device (e.g. on CPU-only machines).
    :return: 0-dim tensor, the sum over the batch of max(d1, d2)
    :raises ValueError: if an input is not 3-D, or the batch sizes or the
        point dimensionality D differ
    '''
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if p1.dim() != 3 or p2.dim() != 3:
        raise ValueError(
            'expected 3-D inputs [B, D, N] and [B, D, M], got {} and {}'
            .format(tuple(p1.size()), tuple(p2.size())))
    if p1.size(0) != p2.size(0) or p1.size(1) != p2.size(1):
        raise ValueError(
            'batch size and point dimension must match, got {} and {}'
            .format(tuple(p1.size()), tuple(p2.size())))

    if debug:
        print(p1[0])

    # Both operands must live on the same device for the subtraction below.
    target = p1.device if device is None else device
    # [B, D, 1, N] and [B, D, M, 1] broadcast against each other to [B, D, M, N].
    p1 = p1.unsqueeze(2).to(device=target)
    p2 = p2.unsqueeze(3).to(device=target)
    if debug:
        print('p1 size is {}'.format(p1.size()))
        print('p2 size is {}'.format(p2.size()))
        print(p1[0][0])

    diff = p1 - p2
    if debug:
        print('dist size is {}'.format(diff.size()))
        print(diff[0])

    # Euclidean norm over D:
    # dist[b, j, i] = ||p1[b, :, i] - p2[b, :, j]||, size [B, M, N].
    dist = torch.norm(diff, 2, dim=1)
    if debug:
        print('dist size is {}'.format(dist.size()))
        print(dist)

    # Distance from every point of p1 to its nearest point in p2 -> [B, N].
    dist1 = torch.min(dist, dim=-2)[0]
    if debug:
        print('dist1 size is {}'.format(dist1.size()))
        print(dist1)

    # Distance from every point of p2 to its nearest point in p1 -> [B, M].
    dist2 = torch.min(dist, dim=-1)[0]
    if debug:
        print('dist2 size is {}'.format(dist2.size()))
        print(dist2)

    # Sum of squared nearest-neighbour distances per batch item -> [B].
    dist1 = torch.sum(dist1 ** 2, dim=1)
    dist2 = torch.sum(dist2 ** 2, dim=1)
    if debug:
        print('-------')
        print(dist1, dist2)

    # Per-item max of the two directions ([B, 2] -> [B]), summed over the batch.
    return torch.sum(torch.max(torch.stack((dist1, dist2), dim=1), dim=1)[0])

def chamfer_distance_with_batch_mean(p1, p2, debug=False, device="cuda"):
    '''
    Calculate a symmetric, point-averaged Chamfer distance between two
    batches of point sets, averaged over the batch.

    Points are stored channels-first: column ``i`` of ``p1[b]`` is the i-th
    D-dimensional point of batch item ``b``. N and M may differ.

    For each batch item the two directional terms use plain (unsquared)
    Euclidean distances averaged over points:

        d1 = mean_i min_j ||p1_i - p2_j||    (over the N points of p1)
        d2 = mean_j min_i ||p1_i - p2_j||    (over the M points of p2)

    They are combined with max(d1, d2), and the per-item values are then
    averaged over the batch.

    Note: an intermediate difference tensor of size [B, D, M, N] is built,
    so memory grows with B * D * M * N.

    :param p1: tensor of size [B, D, N]
    :param p2: tensor of size [B, D, M]
    :param debug: whether to print intermediate tensors
    :param device: device both inputs are moved to before computing.
        Defaults to "cuda" (the previously hard-coded behaviour); pass None
        to compute on p1's current device (e.g. on CPU-only machines).
    :return: 0-dim tensor, the mean over the batch of max(d1, d2)
    :raises ValueError: if an input is not 3-D, or the batch sizes or the
        point dimensionality D differ
    '''
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if p1.dim() != 3 or p2.dim() != 3:
        raise ValueError(
            'expected 3-D inputs [B, D, N] and [B, D, M], got {} and {}'
            .format(tuple(p1.size()), tuple(p2.size())))
    if p1.size(0) != p2.size(0) or p1.size(1) != p2.size(1):
        raise ValueError(
            'batch size and point dimension must match, got {} and {}'
            .format(tuple(p1.size()), tuple(p2.size())))

    if debug:
        print(p1[0])

    # Both operands must live on the same device for the subtraction below.
    target = p1.device if device is None else device
    # [B, D, 1, N] and [B, D, M, 1] broadcast against each other to [B, D, M, N].
    p1 = p1.unsqueeze(2).to(device=target)
    p2 = p2.unsqueeze(3).to(device=target)
    if debug:
        print('p1 size is {}'.format(p1.size()))
        print('p2 size is {}'.format(p2.size()))
        print(p1[0][0])

    diff = p1 - p2
    if debug:
        print('dist size is {}'.format(diff.size()))
        print(diff[0])

    # Euclidean norm over D:
    # dist[b, j, i] = ||p1[b, :, i] - p2[b, :, j]||, size [B, M, N].
    dist = torch.norm(diff, 2, dim=1)
    if debug:
        print('dist size is {}'.format(dist.size()))
        print(dist)

    # Distance from every point of p1 to its nearest point in p2 -> [B, N].
    dist1 = torch.min(dist, dim=-2)[0]
    if debug:
        print('dist1 size is {}'.format(dist1.size()))
        print(dist1)

    # Distance from every point of p2 to its nearest point in p1 -> [B, M].
    dist2 = torch.min(dist, dim=-1)[0]
    if debug:
        print('dist2 size is {}'.format(dist2.size()))
        print(dist2)

    # Mean nearest-neighbour distance per batch item -> [B].
    dist1 = torch.mean(dist1, dim=1)
    dist2 = torch.mean(dist2, dim=1)

    # Per-item max of the two directions ([B, 2] -> [B]), summed over the batch.
    dist = torch.sum(torch.max(torch.stack((dist1, dist2), dim=1), dim=1)[0])

    if debug:
        print('-------')
        print(dist1, dist2)
        print(dist.shape)

    # Divide by the batch size to turn the batch sum into a batch mean.
    return dist / dist1.shape[0]