from . import ot_problem, solvers
import numpy as np
import torch
import torch.nn as nn
from scipy.special import softmax, logsumexp
