from scipy.optimize import linear_sum_assignment
import torch
from tqdm import tqdm

def hungarian(pc_source, pc_target, return_matching=False):
    '''
        Optimal one-to-one matching cost between batches of point clouds.

        For every batch element, the linear assignment problem on the pairwise
        Euclidean distance matrix is solved with scipy's
        ``linear_sum_assignment``, and the distances of the matched pairs are
        summed. With N != M, scipy matches min(N, M) pairs.

        pc_source : [torch.Tensor] B x N x D (or N x D for a single cloud)
        pc_target : [torch.Tensor] B x M x D (or M x D for a single cloud)
            A single (2-D) cloud on either side is matched against every
            cloud of the other side's batch (torch.cdist broadcasts it).
        return_matching : if True, also return the per-element assignments

        Returns:
            costs : [torch.Tensor] (B,) matching costs. The assignment is
                computed on a detached copy, but the cost is gathered from
                the live distance tensor, so gradients flow through the
                matched distances.
            assignments (only if return_matching) : list of B
                ``(row_ind, col_ind)`` numpy array pairs.
    '''
    # Promote each side independently; checking only the source would turn a
    # 3-D target into 4-D when the source is a single cloud.
    if pc_source.dim() == 2:
        pc_source = pc_source[None]
    if pc_target.dim() == 2:
        pc_target = pc_target[None]

    # Compute pairwise distances
    dists = torch.cdist(pc_source, pc_target)  # B x N x M
    # dists = dists**2

    costs = []
    assignments = []

    # Iterate over batch
    for dist in dists:
        # The solver needs a NumPy array; it runs on a detached copy.
        assignment = linear_sum_assignment(dist.detach().cpu().numpy())

        # Gather the cost from the original tensor so autograd sees it.
        costs.append(torch.sum(dist[assignment]))

        if return_matching:
            assignments.append(assignment)

    # torch.stack rejects an empty list, so an empty batch yields an empty result.
    costs = torch.stack(costs) if costs else dists.new_zeros((0,))

    if return_matching:
        return costs, assignments
    return costs

def hungarian_batched(pc_source, pc_target, return_matching=False, batch_size=128):
    '''
        Same as ``hungarian`` but computes pairwise distances in chunks of
        ``batch_size`` batch elements to bound peak memory.

        pc_source : [torch.Tensor] B x N x D (or N x D for a single cloud)
        pc_target : [torch.Tensor] B x M x D (or M x D for a single cloud)
            A single cloud (2-D, or a batch of one) on either side is matched
            against every cloud of the other side's batch.
        return_matching : if True, also return the per-element assignments
        batch_size : number of batch elements per distance chunk

        Returns:
            costs : [torch.Tensor] (B,) matching costs (differentiable
                through the matched distances).
            assignments (only if return_matching) : list of B
                ``(row_ind, col_ind)`` numpy array pairs.

        Raises:
            ValueError : if the two batch sizes differ and neither is 1.
    '''
    # Promote each side independently (an unbatched target must not be
    # sliced along its point dimension below).
    if pc_source.dim() == 2:
        pc_source = pc_source[None]
    if pc_target.dim() == 2:
        pc_target = pc_target[None]

    # Both tensors are sliced with the same chunk range, so a batch of one is
    # expanded (a view, no copy) to the other side's batch size.
    n_src, n_tgt = len(pc_source), len(pc_target)
    if n_src != n_tgt and 1 not in (n_src, n_tgt):
        raise ValueError(
            f'batch sizes differ: {n_src} source vs {n_tgt} target clouds')
    n_batch = n_tgt if n_src == 1 else n_src
    pc_source = pc_source.expand(n_batch, *pc_source.shape[1:])
    pc_target = pc_target.expand(n_batch, *pc_target.shape[1:])

    costs = []
    assignments = []

    # Iterate over batch chunks
    for start in range(0, n_batch, batch_size):
        stop = start + batch_size
        # Compute pairwise distances for this chunk only
        dists = torch.cdist(pc_source[start:stop], pc_target[start:stop])  # b x N x M

        for dist in dists:
            # Solve on a detached NumPy copy; gather the cost from the live
            # tensor so gradients flow through the matched distances.
            assignment = linear_sum_assignment(dist.detach().cpu().numpy())
            costs.append(torch.sum(dist[assignment]))

            if return_matching:
                assignments.append(assignment)

    # torch.stack rejects an empty list, so an empty batch yields an empty result.
    costs = torch.stack(costs) if costs else pc_source.new_zeros((0,))

    if return_matching:
        return costs, assignments
    return costs
    
def hungarian_batched_grads(pc_source, pc_target, return_matching=False, batch_size=32, device='cuda'):
    '''
        Chunked Hungarian matching cost and its gradient w.r.t. the targets.

        The batch is processed in chunks of ``batch_size`` (with a tqdm
        progress bar). Each chunk is moved to ``device`` and matched per
        element with ``linear_sum_assignment``. Then
        ``mean(chunk costs).backward()`` gives the gradient with respect to
        the target points. The inputs are detached first, so any autograd
        graph the caller's tensors belong to is left untouched.

        pc_source : [torch.Tensor] B x N x D (or N x D for a single cloud)
        pc_target : [torch.Tensor] B x M x D (or M x D for a single cloud)
            A single cloud (2-D, or a batch of one) on either side is matched
            against every cloud of the other side's batch.
        return_matching : if True, also return the per-element assignments
        batch_size : number of batch elements per chunk
        device : device each chunk is moved to (default 'cuda', the
            previously hard-coded device)

        Returns:
            costs : [torch.Tensor] (B,) detached matching costs, on the CPU
            grads : [torch.Tensor] B x M x D, on the CPU. Each chunk's costs
                are averaged before backward(), so the gradient for element i
                is d cost_i / d pc_target[i] divided by the size of its chunk
                (``batch_size``, except possibly for the last chunk).
            assignments (only if return_matching) : list of B
                ``(row_ind, col_ind)`` numpy array pairs

        Raises:
            ValueError : if the two batch sizes differ and neither is 1.
    '''
    if pc_source.dim() == 2:
        pc_source = pc_source[None]
    if pc_target.dim() == 2:
        pc_target = pc_target[None]

    # Both tensors are sliced with the same chunk range, so a batch of one is
    # expanded (a view, no copy) to the other side's batch size.
    n_src, n_tgt = len(pc_source), len(pc_target)
    if n_src != n_tgt and 1 not in (n_src, n_tgt):
        raise ValueError(
            f'batch sizes differ: {n_src} source vs {n_tgt} target clouds')
    n_batch = n_tgt if n_src == 1 else n_src
    pc_source = pc_source.expand(n_batch, *pc_source.shape[1:])
    pc_target = pc_target.expand(n_batch, *pc_target.shape[1:])

    costs = []
    assignments = []
    grads = []

    # Iterate over batch chunks
    for start in tqdm(range(0, n_batch, batch_size)):
        stop = start + batch_size

        # Detach both chunks. backward() must not traverse (and free) a graph
        # the caller's tensors belong to. The target must also be a fresh
        # leaf: a slice of a tensor that already requires grad is a non-leaf
        # whose .grad would stay None.
        pc_source_batch = pc_source[start:stop].detach().to(device)
        pc_target_batch = pc_target[start:stop].detach().to(device).requires_grad_(True)

        # Compute pairwise distances
        dists = torch.cdist(pc_source_batch, pc_target_batch)  # b x N x M

        costs_batch = []
        for dist in dists:
            # Solve on a detached NumPy copy; gather the cost from the live
            # tensor so backward() reaches the matched target points.
            assignment = linear_sum_assignment(dist.detach().cpu().numpy())
            costs_batch.append(torch.sum(dist[assignment]))

            if return_matching:
                assignments.append(assignment)

        costs_batch = torch.stack(costs_batch)
        costs_batch.mean().backward()

        # Keep only values; the chunk's graph can be released.
        costs.append(costs_batch.detach())
        grads.append(pc_target_batch.grad)

    if costs:
        costs = torch.cat(costs, dim=0).cpu()
        grads = torch.cat(grads, dim=0).cpu()
    else:
        # torch.cat rejects an empty list; return empty results instead.
        costs = torch.zeros(0, dtype=pc_target.dtype)
        grads = torch.zeros(0, *pc_target.shape[1:], dtype=pc_target.dtype)

    if return_matching:
        return costs, grads, assignments
    return costs, grads

        
# chamfer distance using squared Euclidean point distances
def chamfer(x, y, return_matching=False):
    '''
    Chamfer distance with squared Euclidean point distances.

    x: (B, N, D) or (N, D)
    y: (B, M, D) or (M, D)
    return_matching: if True, also return both directional terms and the
        nearest-neighbour indices.

    Returns (B,) losses, or, with return_matching,
    (loss, x_to_y (B,), x_nn (B, N), y_to_x (B,), y_nn (B, M)), where
    x_nn[b, i] indexes the nearest y point of x[b, i] and vice versa.
    '''
    # Promote unbatched clouds to a batch of one.
    x = x[None] if x.dim() == 2 else x
    y = y[None] if y.dim() == 2 else y

    sq_dist = torch.cdist(x, y).pow(2)  # (B, N, M)

    # Nearest y for every x point, and nearest x for every y point.
    x_to_y, x_nn = sq_dist.min(dim=2)  # (B, N), (B, N)
    y_to_x, y_nn = sq_dist.min(dim=1)  # (B, M), (B, M)
    x_to_y = x_to_y.sum(dim=1)
    y_to_x = y_to_x.sum(dim=1)

    total = x_to_y + y_to_x  # (B,)
    if not return_matching:
        return total
    return total, x_to_y, x_nn, y_to_x, y_nn
    
# chamfer distance using unsquared Euclidean (L2) point distances
def chamferL2(x, y, return_matching=False):
    '''
    Chamfer distance with plain (unsquared) Euclidean point distances.

    x: (B, N, D) or (N, D)
    y: (B, M, D) or (M, D)
    return_matching: if True, also return both directional terms and the
        nearest-neighbour indices.

    Returns (B,) losses, or, with return_matching,
    (loss, x_to_y (B,), x_nn (B, N), y_to_x (B,), y_nn (B, M)), where
    x_nn[b, i] indexes the nearest y point of x[b, i] and vice versa.
    '''
    # Promote unbatched clouds to a batch of one.
    x = x[None] if x.dim() == 2 else x
    y = y[None] if y.dim() == 2 else y

    pair_dist = torch.cdist(x, y)  # (B, N, M)

    # Nearest y for every x point, and nearest x for every y point.
    x_to_y, x_nn = pair_dist.min(dim=2)  # (B, N), (B, N)
    y_to_x, y_nn = pair_dist.min(dim=1)  # (B, M), (B, M)
    x_to_y = x_to_y.sum(dim=1)
    y_to_x = y_to_x.sum(dim=1)

    total = x_to_y + y_to_x  # (B,)
    if not return_matching:
        return total
    return total, x_to_y, x_nn, y_to_x, y_nn
