import torch

def unbalanced_sinkhorn(
    cost, u, v, ws, wt, logmp, eps, rho_s, rho_t, niters, tol
):
    """
    Scaling algorithm (ie Sinkhorn algorithm).
    Code adapted from Séjourné et al 2020:
    https://github.com/thibsej/unbalanced_gromov_wasserstein.

    The dual potentials are updated alternately in the log domain
    (``v`` first, then ``u``) and the coupling is rebuilt from them at
    the end.

    Parameters
    ----------
    cost : torch.Tensor
        Cost matrix. The indexing used below implies shape ``(n, m)``.
    u, v : torch.Tensor
        Initial dual potentials, of shape ``(n,)`` and ``(m,)``.
    ws, wt : torch.Tensor
        Strictly positive weights of shape ``(n,)`` and ``(m,)``
        (their logarithm is taken).
    logmp
        Not used by this function; kept so existing call sites keep
        working.
    eps : float or torch.Tensor
        Entropic regularisation strength.
    rho_s, rho_t : float or torch.Tensor
        Marginal relaxation strengths. ``inf`` gives a damping factor of 1
        for the corresponding potential; ``0`` pins that potential to zero.
    niters : int or None
        Maximum number of iterations (``None`` means no cap).
    tol : float or None
        Stopping threshold on the largest absolute change of the
        potentials, evaluated every 10th iteration (``None`` disables
        the check).

    Returns
    -------
    (u, v) : tuple of torch.Tensor
        Final dual potentials.
    pi : torch.Tensor
        Coupling ``ws[:, None] * wt[None, :] * exp(u + v - cost / eps)``,
        of shape ``(n, m)``.

    Raises
    ------
    ValueError
        If both ``niters`` and ``tol`` are ``None``; the loop would have
        no way to terminate.
    """
    if niters is None and tol is None:
        raise ValueError(
            "At least one of `niters` and `tol` must be set, "
            "otherwise the iteration never terminates."
        )

    ws_dot_wt = ws[:, None] * wt[None, :]
    log_ws = ws.log()
    log_wt = wt.log()

    # torch.isinf only accepts tensors; wrap so that plain Python numbers
    # (e.g. float("inf")) are accepted as well.
    tau_s = (
        1 if torch.isinf(torch.as_tensor(rho_s)) else rho_s / (rho_s + eps)
    )
    tau_t = (
        1 if torch.isinf(torch.as_tensor(rho_t)) else rho_t / (rho_t + eps)
    )

    dual_diff = None
    idx = 0
    while (dual_diff is None or dual_diff >= tol) and (
        niters is None or idx < niters
    ):
        # The convergence criterion is only evaluated every 10th iteration,
        # so the previous potentials are only snapshotted when needed.
        check_convergence = tol is not None and idx % 10 == 0
        if check_convergence:
            u_prev, v_prev = u.detach().clone(), v.detach().clone()

        if rho_t == 0:
            v = torch.zeros_like(v)
        else:
            v = -tau_t * ((u + log_ws)[:, None] - cost / eps).logsumexp(
                dim=0
            )

        if rho_s == 0:
            u = torch.zeros_like(u)
        else:
            u = -tau_s * ((v + log_wt)[None, :] - cost / eps).logsumexp(
                dim=1
            )

        if check_convergence:
            dual_diff = max(
                (u.detach() - u_prev).abs().max(),
                (v.detach() - v_prev).abs().max(),
            )

        idx += 1

    pi = ws_dot_wt * (u[:, None] + v[None, :] - cost / eps).exp()

    return (u, v), pi

def balanced_sinkhorn_stable(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
                        warmstart=None, verbose=False, print_period=20,
                        log=False, cuda=True, **kwargs):
    """Solve entropic-regularised balanced OT with a stabilised Sinkhorn loop.

    The scalings ``u``/``v`` are iterated as in plain Sinkhorn; whenever
    one of them exceeds ``tau`` it is absorbed into the dual potentials
    ``alpha``/``beta`` and the kernel is rebuilt, which keeps the iterates
    in a safe numerical range.
    NOTE(review): the structure mirrors POT's ``sinkhorn_stabilized``;
    presumed origin — confirm.

    Parameters
    ----------
    a : tensor or array-like, shape (na,)
        Source weights. An empty input means uniform weights.
    b : tensor or array-like, shape (nb,) or (nb, n_hists)
        Target weights. An empty input means uniform weights. A 2-D input
        is treated as several target histograms solved jointly.
    M : torch.Tensor, shape (na, nb)
        Cost matrix. Its device and dtype are used for all computations.
    reg : float
        Entropic regularisation strength.
    numItermax : int
        Maximum number of iterations.
    tau : float
        Absorption threshold on ``max(|u|)`` / ``max(|v|)``.
    stopThr : float
        Stopping threshold on the error (see below).
    warmstart : tuple or None
        Optional initial ``(alpha, beta)`` potentials of shape
        ``(na,)`` and ``(nb,)``.
    verbose : bool
        Print the error every ``print_period`` iterations.
    print_period : int
        Period (in iterations) of the error evaluation.
    log : bool
        If true, also return a dict with the keys ``'err'``, ``'logu'``,
        ``'logv'``, ``'alpha'``, ``'beta'`` and ``'warmstart'``.
    cuda : bool
        Kept for backward compatibility and otherwise ignored: work
        tensors are created on the device of ``M``, which is the only
        configuration in which mixing them with ``M`` can succeed.
    **kwargs
        Ignored.

    Returns
    -------
    torch.Tensor
        For 1-D ``b``: the transport plan of shape ``(na, nb)``.
        For 2-D ``b``: a vector of shape ``(n_hists,)`` holding the
        transport cost ``sum(plan * M)`` of each target histogram.
    dict
        Only when ``log`` is true.

    Notes
    -----
    The error is the squared l2 violation of the column marginal for 1-D
    ``b``, and the relative squared change of the scalings for 2-D ``b``.
    It is evaluated on iterations where ``cpt % print_period == 0``.
    """
    M = torch.as_tensor(M)
    if not M.is_floating_point():
        M = M.to(torch.get_default_dtype())
    device, dtype = M.device, M.dtype

    # Bring the marginals to the device/dtype of the cost so that every
    # product below is well defined (also accepts lists / numpy arrays).
    a = torch.as_tensor(a, dtype=dtype, device=device)
    b = torch.as_tensor(b, dtype=dtype, device=device)

    if a.numel() == 0:
        a = torch.ones(M.shape[0], dtype=dtype, device=device) / M.shape[0]
    if b.numel() == 0:
        b = torch.ones(M.shape[1], dtype=dtype, device=device) / M.shape[1]

    # test if multiple target
    if b.dim() > 1:
        nbb = b.shape[1]
        a = a.unsqueeze(1)  # broadcast the single source over all targets
    else:
        nbb = 0

    # init data
    na = len(a)
    nb = len(b)

    if log:
        log = {'err': []}

    if warmstart is None:
        alpha = torch.zeros(na, dtype=dtype, device=device)
        beta = torch.zeros(nb, dtype=dtype, device=device)
    else:
        alpha, beta = (
            torch.as_tensor(p, dtype=dtype, device=device) for p in warmstart
        )

    u_shape = (na, nbb) if nbb else (na,)
    v_shape = (nb, nbb) if nbb else (nb,)
    u = torch.ones(u_shape, dtype=dtype, device=device) / na
    v = torch.ones(v_shape, dtype=dtype, device=device) / nb

    def get_K(alpha, beta):
        """log space computation"""
        return torch.exp(-(M - alpha.reshape(na, 1) -
                        beta.reshape(1, nb)) / reg)

    def get_Gamma(alpha, beta, u, v):
        """log space gamma computation"""
        return torch.exp(-(M - alpha.reshape(na, 1) - beta.reshape(1, nb)) /
                      reg + torch.log(u.reshape(na, 1)) + torch.log(v.reshape(1, nb)))

    K = get_K(alpha, beta)
    cpt = 0
    err = 1

    while True:
        uprev, vprev = u, v

        # sinkhorn update (matmul handles both the vector and the
        # multi-histogram matrix case, unlike torch.mv)
        v = b / (torch.matmul(K.t(), u) + 1e-16)
        u = a / (torch.matmul(K, v) + 1e-16)

        # Checked before any absorption so that NaNs can never be folded
        # into alpha / beta, which would defeat the rollback below.
        if torch.isnan(u).any() or torch.isnan(v).any():
            # we have reached the machine precision
            # come back to previous solution and quit loop
            print('Warning: numerical errors at iteration', cpt)
            u = uprev
            v = vprev
            break

        # remove numerical problems and store them in K
        if torch.abs(u).max() > tau or torch.abs(v).max() > tau:
            if nbb:
                # alpha / beta are shared by all histograms: absorb the
                # per-row maximum and rescale u / v by the same factor so
                # that diag(u) K diag(v) is unchanged for every histogram.
                u_scale = u.max(dim=1).values
                v_scale = v.max(dim=1).values
                alpha = alpha + reg * torch.log(u_scale)
                beta = beta + reg * torch.log(v_scale)
                u = u / u_scale[:, None]
                v = v / v_scale[:, None]
            else:
                alpha = alpha + reg * torch.log(u)
                beta = beta + reg * torch.log(v)
                # Everything has been moved into alpha / beta, so the
                # scalings restart from exactly one; any other constant
                # would rescale the implied plan until the next update.
                u = torch.ones(na, dtype=dtype, device=device)
                v = torch.ones(nb, dtype=dtype, device=device)
            K = get_K(alpha, beta)

        if cpt % print_period == 0:
            # the error is only evaluated every `print_period` iterations
            # to keep the loop cheap
            if nbb:
                err = torch.sum((u - uprev)**2) / torch.sum((u)**2) + \
                    torch.sum((v - vprev)**2) / torch.sum((v)**2)
            else:
                transp = get_Gamma(alpha, beta, u, v)
                err = torch.norm((torch.sum(transp, dim=0) - b))**2
            if log:
                log['err'].append(err)

            if verbose:
                if cpt % (print_period * 20) == 0:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(cpt, err))

        if err <= stopThr or cpt >= numItermax:
            break

        cpt = cpt + 1

    if nbb:
        # one transport cost per target histogram
        result = torch.stack([
            torch.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
            for i in range(nbb)
        ])
    else:
        result = get_Gamma(alpha, beta, u, v)

    if log:
        # in the multi-histogram case the potentials must broadcast
        # against the (n, n_hists) scalings
        alpha_b = alpha[:, None] if nbb else alpha
        beta_b = beta[:, None] if nbb else beta
        log['logu'] = alpha_b / reg + torch.log(u)
        log['logv'] = beta_b / reg + torch.log(v)
        log['alpha'] = alpha_b + reg * torch.log(u)
        log['beta'] = beta_b + reg * torch.log(v)
        log['warmstart'] = (log['alpha'], log['beta'])
        return result, log

    return result