import math
import torch


def sinkhorn(K, eps=0.05, max_iter=10):
    """Rescale ``exp(K / eps)`` with alternating row/column normalisation.

    The kernel ``exp(K / eps)`` is multiplied by a per-row and a per-column
    scaling vector. Each iteration first refits the column scaling so every
    column sums to ``m / n``, then refits the row scaling so every row sums
    to 1. Because the row step runs last, rows sum to 1 (up to rounding)
    whenever ``max_iter >= 1``; column sums only approach ``m / n`` as the
    iterations converge.

    Args:
        K: 2-D tensor of shape ``(m, n)``. Larger entries receive more mass,
            since the kernel is ``exp(K / eps)``.
        eps: Temperature dividing ``K`` before exponentiation.
        max_iter: Number of column/row update rounds. With 0 the unscaled
            kernel ``exp(K / eps)`` is returned.

    Returns:
        Tensor of shape ``(m, n)`` holding the rescaled kernel.
    """
    # NOTE: the exponential is taken directly, so very large K / eps can
    # overflow the floating-point range of the tensor's dtype.
    kernel = (K / eps).exp()
    num_rows, num_cols = kernel.shape

    # Scalings are kept in broadcast-ready shapes: (m, 1) and (1, n).
    row_scale = kernel.new_ones((num_rows, 1))
    col_scale = kernel.new_ones((1, num_cols))
    col_target = float(num_rows / num_cols)

    for _ in range(max_iter):
        col_scale = col_target / torch.mm(row_scale.view(1, -1), kernel)
        row_scale = 1. / torch.mm(kernel, col_scale.view(-1, 1))

    return row_scale * (kernel * col_scale)

def log_sinkhorn(K, a=None, b=None, eps=1.0, max_iter=10):
    """Sinkhorn normalisation of ``exp(K / eps)`` computed in the log domain.

    Two log-scaling vectors are updated alternately with ``logsumexp``
    instead of exponentiating the scores up front. Each iteration first
    refits the column potential so that column sums equal ``b``, then the
    row potential so that row sums equal ``a``. Because the row step runs
    last, row sums match ``a`` (up to rounding) whenever ``max_iter >= 1``;
    column sums only approach ``b`` as the iterations converge.

    Args:
        K: 2-D tensor of shape ``(m, n)``. Larger entries receive more mass.
        a: Optional tensor of target row sums (must broadcast against a
            length-``m`` vector). Defaults to 1 for every row.
        b: Optional tensor of target column sums (must broadcast against a
            length-``n`` vector). Defaults to ``m / n`` for every column.
        eps: Temperature dividing ``K``.
        max_iter: Number of column/row update rounds. With 0 both potentials
            stay at zero and ``exp(K / eps)`` is returned.

    Returns:
        Tensor of shape ``(m, n)`` holding the rescaled kernel.
    """
    m, n = K.shape

    # Both potentials start at zero. Initialising ``u`` here (not only inside
    # the loop) keeps ``max_iter=0`` well defined.
    u = K.new_zeros((n,))
    v = K.new_zeros((m,))

    # Default marginals give each row unit mass and each column m / n, so the
    # result carries total mass m rather than being normalised to sum to 1.
    log_a = 0 if a is None else torch.log(a)
    log_b = math.log(m / n) if b is None else torch.log(b)

    K = K / eps

    for _ in range(max_iter):
        # Column step: reduce over rows (dim 0) with the current row potential.
        u = -torch.logsumexp(v.view(m, 1) + K, dim=0) + log_b
        # Row step: reduce over columns (dim 1) with the fresh column potential.
        v = -torch.logsumexp(u.view(1, n) + K, dim=1) + log_a

    return torch.exp(K + u.view(1, n) + v.view(m, 1))
