import torch


def batched_cdist(x: torch.Tensor, y: torch.Tensor, p: float = 2.0, magic_number: float = 2.5e+8,
                  device='cuda') -> torch.Tensor:
    """Compute pairwise p-norm distances between rows of *x* and rows of *y*.

    The computation runs on ``device``; the result is moved back to the
    device ``x`` lived on. Unless ``magic_number`` is negative, ``x`` is
    processed in chunks along dim 0 so that each ``torch.cdist`` call
    produces roughly at most ``magic_number`` distance entries.

    Args:
        x: Left operand. NOTE(review): the chunked path adds a leading batch
            dim and strips it again, so 2-D ``(N, D)`` input appears to be
            expected -- confirm against callers.
        y: Right operand; presumably ``(M, D)`` -- confirm against callers.
        p: Norm degree forwarded to ``torch.cdist``.
        magic_number: Budget used to derive the chunk size as
            ``int(magic_number / len(y))`` rows of ``x`` per chunk. A
            negative value disables chunking entirely.
        device: Device on which the distances are computed.

    Returns:
        The distance matrix, on the original device of ``x``.
    """
    origin_device = x.device

    # Negative budget: caller explicitly asked for a single un-chunked call.
    if magic_number < 0:
        return torch.cdist(x.to(device), y.to(device), p=p).to(origin_device)

    num_y = len(y)
    if num_y:
        # Clamp to 1: when len(y) exceeds the budget the quotient truncates
        # to 0, and torch.split rejects a zero split size on non-empty input.
        rows_per_chunk = max(1, int(magic_number / num_y))
    else:
        # Empty y yields an (N, 0) result; avoid dividing by zero and just
        # process x in one piece.
        rows_per_chunk = max(1, len(x))

    # y is identical for every chunk, so transfer it to the device only once.
    y_on_device = y[None].to(device)

    pieces = [
        torch.cdist(chunk[None].to(device), y_on_device, p=p)[0].to(origin_device)
        for chunk in torch.split(x, rows_per_chunk)
    ]
    return torch.cat(pieces, dim=0)
