import math
import torch


class SinkhornPadded(torch.autograd.Function):
    """Wasserstein distance with entropy regularization.

    See Cuturi 2013, Sinkhorn Distances - Lightspeed Computation of Optimal Transport
    and Alg. 1 in Frogner 2015, Learning with a Wasserstein Loss.
    In our case h(x) = y = 1.

    Also inspired by Mocha.jl (https://github.com/pluskid/Mocha.jl/blob/master/src/layers/wasserstein-loss.jl)
    and the Python Optimal Transport library (https://github.com/rflamary/POT/blob/master/ot/bregman.py).

    reg = 1 / lambda presents a trade-off: Smaller values are closer to the true EMD but converge more slowly and are less stable

    The inputs are batched and padded: for batch element b only the first nnodes[b] rows and
    columns of cost_mat[b] take part in the transport problem, the rest is treated as padding.
    """
    @staticmethod
    def forward(
            ctx, cost_mat: torch.FloatTensor,
            nnodes: torch.LongTensor,
            reg: torch.FloatTensor,
            niter: int = 50):
        """Run niter Sinkhorn iterations and return the transport cost per batch element.

        Args:
            cost_mat: Padded cost matrices, unpacked as (batch, max_nodes, max_nodes).
            nnodes: Number of valid nodes per batch element, indexed as a 1-D tensor of length batch.
            reg: Regularization strength per batch element, indexed as a 1-D tensor of length batch.
            niter: Number of Sinkhorn iterations; must be at least 1.

        Returns:
            Tensor of shape (batch,) holding sum(cost_mat * T) for the transport plan T.

        Raises:
            ValueError: If niter is smaller than 1.
            RuntimeError: If any entry of exp(-cost_mat / reg) is below 1e-13 (checked before the
                padding is masked, so padded cost entries are included), or if the scaling
                vectors contain NaN or exceed 1e10 after the iterations.
        """
        if niter < 1:
            # Without at least one iteration the scaling vector v is never computed.
            raise ValueError(f"niter must be at least 1, got {niter}")

        batch_size, max_nodes, _ = cost_mat.shape
        K = torch.exp(-cost_mat / reg[:, None, None])
        k_min = K.min()
        if k_min < 1e-13:
            raise RuntimeError(f"Sinkhorn reg too small. Smallest K = {k_min}")
        u = cost_mat.new_full([batch_size, max_nodes], 1 / max_nodes)

        # Fill padding with 0: an entry is padding if its row or its column index is >= nnodes
        index = torch.arange(max_nodes, dtype=torch.int64, device=cost_mat.device)
        is_pad = index[None, :] >= nnodes[:, None]
        mask_outer = is_pad[:, :, None] | is_pad[:, None, :]
        K = K.masked_fill(mask_outer, 0)

        # Mask that acts only on K's first column.
        # It puts a 1 into the otherwise all-zero padded rows so that the divisions below
        # never divide by zero. The unmodified K is used for the transport plan.
        mask_vector = is_pad.to(K.dtype)
        K_u = K.clone().transpose(1, 2)
        K_u[:, :, 0] += mask_vector

        K_v = K.clone()
        K_v[:, :, 0] += mask_vector

        for _ in range(niter):
            v = 1 / torch.einsum("bij, bj -> bi", K_u, u)
            u = 1 / torch.einsum("bij, bj -> bi", K_v, v)

        if (torch.isnan(u).any()
                or torch.isnan(v).any()
                or u.max() > 1e10
                or v.max() > 1e10):
            raise RuntimeError(f"Excessively large/nan values: u.max = {u.max():.2e}, v.max = {v.max():.2e}")

        # T = diag(u) K diag(v), written element-wise to avoid two batched matrix products
        T = u[:, :, None] * K * v[:, None, :]

        ctx.save_for_backward(T)

        return torch.sum(cost_mat * T, dim=[1, 2])

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. cost_mat, using the saved transport plan T as the derivative.

        The dependence of T itself on cost_mat is not differentiated. One entry is returned per
        forward input (cost_mat, nnodes, reg, niter); only cost_mat receives a gradient.
        """
        T, = ctx.saved_tensors
        return T * grad_output[:, None, None], None, None, None


class LogSinkhornPadded(torch.autograd.Function):
    """Entropy-regularized Wasserstein distance with the Sinkhorn iterations carried out in log space.

    References: Peyré, SinkhornAutoDiff (https://github.com/gpeyre/SinkhornAutoDiff/blob/master/sinkhorn_pointcloud.py)
    and Daza, Approximating Wasserstein distances with PyTorch (https://dfdazac.github.io/sinkhorn.html).
    Both marginals are fixed to one (mu = nu = 1).

    reg = 1 / lambda is a trade-off: smaller values are closer to the true EMD but converge more slowly
    and are less stable.

    Inputs are batched and padded: for batch element b only the first nnodes[b] rows and columns of
    cost_mat[b] are valid, everything else is treated as infinitely expensive.
    """
    @staticmethod
    def forward(
            ctx, cost_mat: torch.FloatTensor,
            nnodes: torch.LongTensor,
            reg: torch.FloatTensor,
            niter: int = 50,
            offset_entropy: bool = True):
        """Compute the regularized transport cost for every batch element.

        Args:
            cost_mat: Padded cost matrices, unpacked as (batch, max_nodes, max_nodes).
            nnodes: Number of valid nodes per batch element, indexed as a 1-D tensor.
            reg: Regularization strength per batch element, indexed as a 1-D tensor.
            niter: Number of log-space Sinkhorn iterations.
            offset_entropy: If True, return reg * (<u, row sums of T> + <v, column sums of T>)
                built from the scaling potentials; otherwise return sum(cost_mat * T).

        Returns:
            Tensor with one value per batch element.
        """
        batch_size, max_nodes, _ = cost_mat.shape

        # An entry is padding when its row index or its column index is >= nnodes.
        node_index = torch.arange(max_nodes, dtype=torch.int64, device=cost_mat.device)
        limit = nnodes[:, None, None]
        padding = (node_index[:, None] >= limit) | (node_index >= limit)

        # Padding gets infinite cost; the cost is pre-divided by reg so the loop works on c / reg.
        scaled_cost = cost_mat.masked_fill(padding, math.inf) / reg[:, None, None]

        # The log-kernel is M_ij = -c_ij + u_i + v_j. The potentials of padded rows/columns become
        # +inf, so they are clamped before subtracting the infinite cost to avoid inf - inf = NaN.
        u = cost_mat.new_zeros((batch_size, max_nodes))
        v = cost_mat.new_zeros((batch_size, max_nodes))
        for _ in range(niter):
            u = -torch.logsumexp(torch.clamp(v.unsqueeze(1), max=1e10) - scaled_cost, dim=-1)
            v = -torch.logsumexp(torch.clamp(u.unsqueeze(2), max=1e10) - scaled_cost, dim=-2)

        plan = torch.exp(torch.clamp(u.unsqueeze(2) + v.unsqueeze(1), max=1e10) - scaled_cost)

        ctx.save_for_backward(plan)

        if not offset_entropy:
            # Clamp so that an infinite padded cost times a zero plan entry gives 0 instead of NaN.
            return (torch.clamp(cost_mat, max=1e10) * plan).sum(dim=[1, 2])

        row_term = (torch.clamp(u, max=1e10) * plan.sum(2)).sum(1)
        col_term = (torch.clamp(v, max=1e10) * plan.sum(1)).sum(1)
        return reg * (row_term + col_term)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the saved transport plan, scaled by grad_output, as the gradient w.r.t. cost_mat.

        The remaining four forward inputs receive no gradient.
        """
        (plan,) = ctx.saved_tensors
        grad_cost = grad_output[:, None, None] * plan
        return grad_cost, None, None, None, None


class LogSinkhornPaddedRect(torch.autograd.Function):
    """Wasserstein distance with entropy regularization, calculated in log space.

    See Peyré, SinkhornAutoDiff (https://github.com/gpeyre/SinkhornAutoDiff/blob/master/sinkhorn_pointcloud.py)
    and Daza, Approximating Wasserstein distances with PyTorch (https://dfdazac.github.io/sinkhorn.html)
    In our case mu = nu = 1.

    reg = 1 / lambda presents a trade-off: Smaller values are closer to the true EMD but converge more slowly and are less stable

    Rectangular variant: the number of valid rows and the number of valid columns are given
    separately per batch element (nnodes[0] and nnodes[1]), and the padded cost matrices
    themselves may be non-square.
    """
    @staticmethod
    def forward(
            ctx, cost_mat: torch.FloatTensor,
            nnodes: torch.LongTensor,
            reg: torch.FloatTensor,
            niter: int = 50,
            offset_entropy: bool = True):
        """Compute the regularized transport cost for every batch element.

        Args:
            cost_mat: Padded cost matrices of shape (batch, max_rows, max_cols).
            nnodes: Indexed as a (2, batch) tensor; nnodes[0] holds the number of valid rows and
                nnodes[1] the number of valid columns of each batch element.
            reg: Regularization strength per batch element, indexed as a 1-D tensor.
            niter: Number of log-space Sinkhorn iterations.
            offset_entropy: If True, return reg * (<u, row sums of T> + <v, column sums of T>)
                built from the scaling potentials; otherwise return sum(cost_mat * T).

        Returns:
            Tensor with one value per batch element.
        """
        batch_size, max_rows, max_cols = cost_mat.shape

        # Fill padding with inf: rows >= nnodes[0] and columns >= nnodes[1] are padding
        row_index = torch.arange(max_rows, dtype=torch.int64, device=cost_mat.device)
        col_index = torch.arange(max_cols, dtype=torch.int64, device=cost_mat.device)
        mask_outer1 = row_index[None, :, None] >= nnodes[0, :, None, None]
        mask_outer2 = col_index[None, None, :] >= nnodes[1, :, None, None]
        mask_outer = mask_outer1 | mask_outer2
        cost_inf = cost_mat.masked_fill(mask_outer, math.inf) / reg[:, None, None]

        def M(u, v):
            """Modified cost for logarithmic updates: $M_{ij} = -c_{ij} + u_i + v_j$"""
            # clamp to prevent NaN for inf - inf
            if u is None:
                return torch.clamp(v[:, None, :], max=1e10) - cost_inf
            elif v is None:
                return torch.clamp(u[:, :, None], max=1e10) - cost_inf
            else:
                return torch.clamp(u[:, :, None] + v[:, None, :], max=1e10) - cost_inf

        u = cost_mat.new_zeros(batch_size, max_rows)
        v = cost_mat.new_zeros(batch_size, max_cols)
        for _ in range(niter):
            u = -torch.logsumexp(M(None, v), dim=-1)
            v = -torch.logsumexp(M(u, None), dim=-2)

        T = torch.exp(M(u, v))

        ctx.save_for_backward(T)

        if offset_entropy:
            return reg * ((torch.clamp(u, max=1e10) * T.sum(2)).sum(1) + (torch.clamp(v, max=1e10) * T.sum(1)).sum(1))
        else:
            # Clamp so that an infinite padded cost times a zero plan entry gives 0 instead of NaN
            return torch.sum(torch.clamp(cost_mat, max=1e10) * T, dim=[1, 2])

    @staticmethod
    def backward(ctx, grad_output):
        """Return the saved transport plan T, scaled by grad_output, as the gradient w.r.t. cost_mat.

        The remaining four forward inputs receive no gradient.
        """
        T, = ctx.saved_tensors
        return T * grad_output[:, None, None], None, None, None, None


@torch.jit.script
def argSinkhornPadded(
        cost_mat: torch.Tensor,
        nnodes: torch.Tensor,
        reg: torch.Tensor,
        niter: int = 50):
    """Return the entropy-regularized optimal transport matrix, computed with log-space Sinkhorn iterations.

    Args:
        cost_mat: Padded cost matrices, unpacked as (batch, max_nodes, max_nodes).
        nnodes: Number of valid nodes per batch element, indexed as a 1-D tensor.
        reg: Regularization strength per batch element, indexed as a 1-D tensor.
        niter: Number of Sinkhorn iterations.

    Returns:
        Transport matrices with the same shape as cost_mat; padded entries are zero.
    """
    batch_size, max_nodes, _ = cost_mat.shape

    # Padding (row or column index >= nnodes) gets infinite cost.
    node_idx = torch.arange(max_nodes, dtype=torch.int64, device=cost_mat.device)
    limit = nnodes[:, None, None]
    pad_rows = node_idx[:, None].expand_as(cost_mat) >= limit
    pad_cols = node_idx.expand_as(cost_mat) >= limit
    cost_inf = cost_mat.masked_fill(pad_rows | pad_cols, math.inf)

    # Views of reg that broadcast against the potentials and the cost matrix.
    reg_vec = reg[:, None]
    reg_mat = reg[:, None, None]

    row_pot = cost_mat.new_zeros(batch_size, max_nodes)
    col_pot = cost_mat.new_zeros(batch_size, max_nodes)
    for _ in range(niter):
        # Clamping from below turns the -inf of padded entries into a large finite negative value.
        row_scores = torch.clamp(col_pot[:, None, :] - cost_inf, min=-1e10) / reg_mat
        row_pot = -reg_vec * torch.logsumexp(row_scores, dim=-1)
        col_scores = torch.clamp(row_pot[:, :, None] - cost_inf, min=-1e10) / reg_mat
        col_pot = -reg_vec * torch.logsumexp(col_scores, dim=-2)

    log_plan = torch.clamp(row_pot[:, :, None] + col_pot[:, None, :] - cost_inf, min=-1e10) / reg_mat
    return torch.exp(log_plan)
