import torch

def log_selfsink(C: torch.Tensor,
                 eps: float = 1.,
                 f: torch.Tensor = None,
                 tol: float = 1e-5,
                 max_iter: int = 1000,
                 student: bool = False,
                 tolog: bool = False):
    """
        Performs Sinkhorn iterations in log domain to solve the entropic "self" (or "symmetric") OT problem with symmetric cost C and entropic regularization epsilon.
        Returns the log of the transport plan and the dual variable at convergence.

        Parameters
        ----------
        C: array (n,n)
            symmetric distance matrix
        eps: float
            entropic regularization coefficient
        f: array(n)
            initial dual variable (warm start); zeros with C's dtype/device if None
        tol: float
            precision threshold at which the algorithm stops: iterations end once
            every row of the plan exp(log_T) sums to 1 within `tol`
        max_iter: int
            maximum number of Sinkhorn iterations
        student: bool
            if True, the cost log(1 + C) is used instead of C, i.e. a heavy-tailed
            (Student-t like) kernel (1 + C)^(-1/eps) instead of the Gaussian exp(-C/eps)
        tolog: bool
            if True, log and returns intermediate variables

        Returns
        -------
        log_T: array (n,n)
            log of the transport plan, (f_i + f_j - C_ij) / eps
        f: array (n)
            dual variable at the last iteration
        log: dict, only if tolog is True
            log['f'] is the list of dual variables: the initial one followed by
            one entry per iteration

        Raises
        ------
        RuntimeError
            if the dual variable contains NaN at some iteration
    """
    n = C.shape[0]

    # Allows a warm-start if a dual variable f is provided. The default is
    # created with C's dtype/device so `f - C` below never mixes devices
    # (a CPU vector cannot be broadcast against a CUDA matrix).
    if f is None:
        f = torch.zeros(n, dtype=C.dtype, device=C.device)
    else:
        f = f.clone().detach()

    if tolog:
        log = {'f': [f.clone().detach()]}

    # If student is True, considers the Student-t kernel instead of Gaussian.
    # log1p is more accurate than log(1 + C) for small costs.
    if student:
        C = torch.log1p(C)

    # Sinkhorn iterations: at most max_iter updates. The `else` branch of the
    # loop runs only when no `break` happened, i.e. convergence was not reached.
    for k in range(max_iter):
        # Averaged symmetric update: f <- (f + T(f)) / 2 with
        # T(f)_i = -eps * log sum_j exp((f_j - C_ij) / eps)
        # (f broadcasts along the last axis of C).
        f = 0.5 * (f - eps * torch.logsumexp((f - C) / eps, -1))

        if tolog:
            log['f'].append(f.clone().detach())

        if torch.isnan(f).any():
            raise RuntimeError(f'NaN in self-Sinkhorn dual variable at iteration {k}')

        # Stop once every row of the plan T = exp(log_T) sums to 1 within tol.
        log_T = (f[:, None] + f[None, :] - C) / eps
        if (torch.abs(torch.exp(torch.logsumexp(log_T, -1)) - 1) < tol).all():
            break
    else:
        print('---------- Max iter attained ----------')

    # Recomputed here so it is defined even when max_iter == 0.
    log_T = (f[:, None] + f[None, :] - C) / eps
    if tolog:
        return log_T, f, log
    return log_T, f