import torch
from torch_scatter.composite import scatter_logsumexp
from utils import segment_coo


class LogSinkhorn(torch.autograd.Function):
    """Wasserstein distance with entropy regularization, calculated in log space.

    See Peyré, SinkhornAutoDiff (https://github.com/gpeyre/SinkhornAutoDiff/blob/master/sinkhorn_pointcloud.py)
    and Daza, Approximating Wasserstein distances with PyTorch (https://dfdazac.github.io/sinkhorn.html)
    In our case mu = nu = 1.

    reg = 1 / lambda presents a trade-off: Smaller values are closer to the true EMD but converge more slowly and are less stable

    The cost matrix is passed in sparse form: ``cost_mat`` holds one value per
    entry and ``cost_idx[0]`` / ``cost_idx[1]`` hold the matching row / column
    index of each entry. Several problems are processed at once; ``batch_idx``
    assigns every entry to its problem.
    NOTE(review): ``segment_coo`` presumably requires ``batch_idx`` to be
    sorted -- confirm against ``utils.segment_coo``.
    """
    @staticmethod
    def forward(
            ctx,
            cost_mat: torch.FloatTensor,
            cost_idx: torch.LongTensor,
            nnodes: torch.LongTensor,
            reg: torch.FloatTensor,
            batch_idx: torch.LongTensor,
            niter: int = 50):
        """Run ``niter`` Sinkhorn iterations and return the per-problem cost.

        Args:
            ctx: Autograd context; stores the transport plan for ``backward``.
            cost_mat: Cost value of every sparse entry.
            cost_idx: Row indices (``cost_idx[0]``) and column indices
                (``cost_idx[1]``) of the entries in ``cost_mat``.
            nnodes: Node count per problem; its sum is the length of the dual
                vectors and its last dimension is the batch size.
            reg: Regularization strength per problem, indexed by ``batch_idx``.
            batch_idx: Problem index of every entry in ``cost_mat``.
            niter: Number of alternating dual updates.

        Returns:
            Sum of ``cost_mat * T`` over the entries of each problem, where
            ``T`` is the transport plan after ``niter`` iterations.
        """
        batch_size = nnodes.shape[-1]
        # Convert to a plain Python int once: `dim_size` is documented as an
        # int, and passing a 0-dim tensor would otherwise be converted again
        # on every scatter call inside the loop.
        nnodes_sum = int(nnodes.sum())

        cost_scaled = cost_mat / reg[batch_idx]

        def M(u, v):
            """Modified cost for logarithmic updates.

            $M_{ij} = -c_{ij} + u_i + v_j$, where a dual passed as ``None``
            is left out of the sum.
            """
            if u is None:
                return v[cost_idx[1]] - cost_scaled
            elif v is None:
                return u[cost_idx[0]] - cost_scaled
            else:
                return (u[cost_idx[0]] + v[cost_idx[1]] - cost_scaled)

        u = cost_mat.new_zeros(nnodes_sum)
        v = cost_mat.new_zeros(nnodes_sum)
        for _ in range(niter):
            # Alternately rescale rows and columns; with mu = nu = 1 the
            # log-marginal terms are zero and drop out of the updates.
            u = -scatter_logsumexp(M(None, v), cost_idx[0], dim_size=nnodes_sum)
            v = -scatter_logsumexp(M(u, None), cost_idx[1], dim_size=nnodes_sum)

        T = torch.exp(M(u, v))
        ctx.save_for_backward(T, batch_idx)

        return segment_coo(cost_mat * T, batch_idx, dim_size=batch_size, reduce='sum')

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient with respect to ``cost_mat`` only.

        The saved transport plan is treated as a constant, so the gradient of
        each entry is its plan value times the upstream gradient of the
        problem the entry belongs to. All other inputs receive ``None``.
        """
        T, batch_idx = ctx.saved_tensors
        return (T * grad_output[batch_idx],
                None, None, None, None, None)


def argSinkhorn(
        cost_mat: torch.FloatTensor,
        cost_idx: torch.LongTensor,
        nnodes: torch.LongTensor,
        reg: torch.FloatTensor,
        batch_idx: torch.LongTensor,
        niter: int = 50):
    """Transport matrix according to Wasserstein distance with entropy regularization, calculated in log space.

    The cost matrix is passed in sparse form: ``cost_mat`` holds one value per
    entry and ``cost_idx[0]`` / ``cost_idx[1]`` hold the matching row / column
    index of each entry.

    Args:
        cost_mat: Cost value of every sparse entry.
        cost_idx: Row indices (``cost_idx[0]``) and column indices
            (``cost_idx[1]``) of the entries in ``cost_mat``.
        nnodes: Node count per problem; its sum is the length of the dual
            vectors.
        reg: Regularization strength per problem, indexed by ``batch_idx``.
        batch_idx: Problem index of every entry in ``cost_mat``.
        niter: Number of alternating dual updates.

    Returns:
        The transport plan ``T = exp(u_i + v_j - c_ij / reg)`` after ``niter``
        iterations, with one value per entry of ``cost_mat``.
    """
    # Convert to a plain Python int once: `dim_size` is documented as an int,
    # and passing a 0-dim tensor would otherwise be converted again on every
    # scatter call inside the loop.
    nnodes_sum = int(nnodes.sum())

    cost_scaled = cost_mat / reg[batch_idx]

    def M(u, v):
        """Modified cost for logarithmic updates.

        $M_{ij} = -c_{ij} + u_i + v_j$, where a dual passed as ``None`` is
        left out of the sum.
        """
        if u is None:
            return v[cost_idx[1]] - cost_scaled
        elif v is None:
            return u[cost_idx[0]] - cost_scaled
        else:
            return (u[cost_idx[0]] + v[cost_idx[1]] - cost_scaled)

    u = cost_mat.new_zeros(nnodes_sum)
    v = cost_mat.new_zeros(nnodes_sum)
    for _ in range(niter):
        # Alternately rescale rows and columns in log space.
        u = -scatter_logsumexp(M(None, v), cost_idx[0], dim_size=nnodes_sum)
        v = -scatter_logsumexp(M(u, None), cost_idx[1], dim_size=nnodes_sum)

    T = torch.exp(M(u, v))
    return T
