import pytest

import torch

from unbalancedgw.utils import generate_measure
from unbalancedgw.vanilla_ugw_solver import log_ugw_sinkhorn, exp_ugw_sinkhorn

# Seed torch's default random number generator at import time so that any
# random tensors drawn afterwards are reproducible from run to run.
# NOTE(review): presumably the inputs built via `generate_measure` rely on
# this seed — confirm against its implementation.
torch.manual_seed(42)