import torch
from torch_scatter.composite import scatter_logsumexp
from utils import logsumexp_signed_signed, logdiffexp, scatter, segment_coo


class LogNystromSinkhornBPdiag(torch.autograd.Function):
    """Entropy-regularized Wasserstein distance with a BP matrix, evaluated in log space.

    The inner cost block is not passed as one dense matrix but as the factor pair
    ``cost_1a`` / ``sim_a2`` (the latter carrying the signs ``sign_a2``).

    See Peyré, SinkhornAutoDiff (https://github.com/gpeyre/SinkhornAutoDiff/blob/master/sinkhorn_pointcloud.py)
    and Daza, Approximating Wasserstein distances with PyTorch (https://dfdazac.github.io/sinkhorn.html)
    In our case mu = nu = 1.

    reg = 1 / lambda presents a trade-off: smaller values are closer to the true EMD,
    but converge more slowly and are less stable.
    """

    @staticmethod
    @torch.jit.script
    def _lse_uv_inner(cost_1a, sim_a2, sign_a2, norms, vec, dim: int):
        """Signed log-sum-exp for the first ``max_nodes`` entries of one Sinkhorn half-step.

        When ``dim % 3 == 2`` the ``sim_a2`` factor is first collapsed over its last axis
        (together with ``vec[:, :max_nodes]``); otherwise the ``cost_1a`` factor is collapsed
        over axis 1. The remaining factor is then reduced over ``dim`` and combined with the
        single extra term ``vec[:, max_nodes:] - norms``.
        """
        _, max_nodes, _ = cost_1a.shape

        if dim % 3 == 2:
            lse_a, sign_a = logsumexp_signed_signed(sim_a2 + vec[:, None, :max_nodes],
                                                    sign_a2, dim=2)
            log_terms = (-cost_1a + lse_a[:, None, :])
            term_signs = sign_a[:, None, :]
        else:
            lse_a = torch.logsumexp(vec[:, :max_nodes, None] - cost_1a, dim=1)
            log_terms = (lse_a[:, :, None] + sim_a2)
            term_signs = sign_a2

        # Shift by the largest log-term (signs ignored) so the exponentials stay bounded.
        peak_inner = log_terms.max(dim).values
        peak_outer = vec[:, max_nodes:] - norms
        shift = torch.max(peak_inner, peak_outer)

        outer_sum = torch.exp(-norms - shift + vec[:, max_nodes:])
        inner_sum = (term_signs * (log_terms - shift.unsqueeze(dim)).exp()).sum(dim)

        return (inner_sum + outer_sum).log() + shift

    @staticmethod
    @torch.jit.script
    def _lse_uv_outer(norms: torch.Tensor, vec: torch.Tensor):
        """Log-sum-exp for the last ``max_nodes`` entries of one Sinkhorn half-step.

        Every output entry combines its own term ``vec[:, :max_nodes] - norms`` with the
        row-wise log-sum-exp over ``vec[:, max_nodes:]``.
        """
        _, max_nodes = norms.shape

        tail_lse = torch.logsumexp(vec[:, max_nodes:], dim=-1)
        head = -norms + vec[:, :max_nodes]

        shift = torch.max(head, tail_lse[:, None])

        exp_head = torch.exp(head - shift)
        exp_tail = torch.exp(tail_lse[:, None] - shift)

        return (exp_head + exp_tail).log() + shift

    @staticmethod
    def forward(
            ctx,
            cost_1a: torch.FloatTensor,
            sim_a2: torch.FloatTensor,
            sign_a2: torch.FloatTensor,
            norms1: torch.FloatTensor,
            norms2: torch.FloatTensor,
            nnodes: torch.LongTensor,
            reg: torch.FloatTensor,
            niter: int = 50):
        """Run ``niter`` log-space Sinkhorn iterations and return the per-sample cost.

        Args:
            cost_1a: First cost factor, 3-D with ``max_nodes`` on axis 1.
            sim_a2: Second (log-similarity) factor, 3-D with ``max_nodes`` on the last axis.
            sign_a2: Signs belonging to ``sim_a2``.
            norms1: ``(batch, max_nodes)`` costs paired with the first node set.
            norms2: ``(batch, max_nodes)`` costs paired with the second node set.
            nnodes: Node counts, read as ``nnodes[0]`` and ``nnodes[1]`` (one value per
                sample); positions at or beyond the count are treated as padding.
            reg: Per-sample regularization, one value per batch entry.
            niter: Number of Sinkhorn iterations.

        Returns:
            Tensor with one cost value per sample.
        """
        batch_size, max_nodes = norms1.shape
        n = max_nodes

        # Rescale by the regularization; the iterations work on copies capped at 1e20.
        cost_1a_scaled = cost_1a / reg[:, None, None]
        norms1_scaled = norms1 / reg[:, None]
        norms2_scaled = norms2 / reg[:, None]

        cost_1a_iter = cost_1a_scaled.clamp(max=1e20)
        norms1_iter = norms1_scaled.clamp(max=1e20)
        norms2_iter = norms2_scaled.clamp(max=1e20)

        # True wherever a position is padding (index >= number of nodes of that sample).
        pad1 = (torch.arange(n, dtype=torch.int64,
                             device=norms1.device).expand_as(norms1)
                >= nnodes[0, :, None])
        pad2 = (torch.arange(n, dtype=torch.int64,
                             device=norms2.device).expand_as(norms2)
                >= nnodes[1, :, None])
        pad_u = torch.cat((pad1, pad2), dim=1)
        pad_v = torch.cat((pad2, pad1), dim=1)

        lse_inner = LogNystromSinkhornBPdiag._lse_uv_inner
        lse_outer = LogNystromSinkhornBPdiag._lse_uv_outer

        u = norms1.new_zeros(batch_size, 2 * n)
        v = norms1.new_zeros(batch_size, 2 * n)
        for _ in range(niter):
            u[:, :n] = -lse_inner(cost_1a_iter, sim_a2, sign_a2, norms1_iter, v, dim=2)
            u[:, n:] = -lse_outer(norms2_iter, v)
            u.masked_fill_(pad_u, -1e10)
            v[:, :n] = -lse_inner(cost_1a_iter, sim_a2, sign_a2, norms2_iter, u, dim=1)
            v[:, n:] = -lse_outer(norms1_iter, u)
            v.masked_fill_(pad_v, -1e10)

        # u and v are final from here on, so views on their halves are safe to reuse.
        u_head, u_tail = u[:, :n], u[:, n:]
        v_head, v_tail = v[:, :n], v[:, n:]

        T12 = (-norms1_scaled + u_head + v_tail).exp()
        T21 = (-norms2_scaled + u_tail + v_head).exp()
        T22_u = (u_tail + v_tail.logsumexp(-1)[:, None]).exp()
        T22_v = (u_tail.logsumexp(-1)[:, None] + v_tail).exp()

        # Factorized inner block: keep both factors in log space and only sum out one side.
        T1a_log = -cost_1a_scaled + u_head[:, :, None]
        Ta1_log = sim_a2 + v[:, None, :n]

        Ta1_logsum, Ta1_sum_sign = logsumexp_signed_signed(Ta1_log, sign_a2, dim=2)
        T1a_logsum = torch.logsumexp(T1a_log, dim=1)

        T11_sumright = Ta1_sum_sign[:, None, :] * torch.exp(T1a_log + Ta1_logsum[:, None, :])
        T11_sumleft = sign_a2 * torch.exp(T1a_logsum[:, :, None] + Ta1_log)

        ctx.save_for_backward(T11_sumleft, T11_sumright, T12, T21, reg)

        C11 = reg * ((u_head * T11_sumright.sum(2)).sum(1)
                     + (v_head * T11_sumleft.sum(1)).sum(1))
        C12 = reg * ((u_head * T12).sum(1) + (v_tail * T12).sum(1))
        C21 = reg * ((u_tail * T21).sum(1) + (v_head * T21).sum(1))
        C22 = reg * ((u_tail * T22_u).sum(1) + (v_tail * T22_v).sum(1))

        return C11 + C12 + C21 + C22

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``cost_1a``, ``sim_a2``, ``norms1`` and ``norms2``.

        All other forward inputs receive ``None``.
        """
        T11_sumleft, T11_sumright, T12, T21, reg = ctx.saved_tensors

        grad_3d = grad_output[:, None, None]
        grad_2d = grad_output[:, None]

        grad_cost_1a = T11_sumright * grad_3d
        grad_sim_a2 = -reg[:, None, None] * T11_sumleft * grad_3d
        grad_norms1 = T12 * grad_2d
        grad_norms2 = T21 * grad_2d

        return (grad_cost_1a,
                grad_sim_a2,
                None,
                grad_norms1,
                grad_norms2,
                None, None, None)


class LogSparseSinkhornBPdiag(torch.autograd.Function):
    """Wasserstein distance with entropy regularization, calculated in log space.

    Sparse variant: the inner cost block is given as a flat vector of entries
    (``cost_mat``) together with index vectors, and all graphs of a batch are
    concatenated into flat ``u`` / ``v`` vectors.

    See Peyré, SinkhornAutoDiff (https://github.com/gpeyre/SinkhornAutoDiff/blob/master/sinkhorn_pointcloud.py)
    and Daza, Approximating Wasserstein distances with PyTorch (https://dfdazac.github.io/sinkhorn.html)
    In our case mu = nu = 1.

    reg = 1 / lambda presents a trade-off: Smaller values are closer to the true EMD but converge more slowly and are less stable

    """

    @staticmethod
    @torch.jit.script
    def _lse_uv_inner(cost_mat, cost_vec_idx, cost_scatter_idx,
                      norms, norms_outer_batch_idx, vec, batch_size: int):
        """Log-sum-exp for the leading part of ``u`` (or ``v``) in one Sinkhorn half-step.

        Each sparse entry contributes ``-cost_mat + vec[cost_vec_idx]`` to the output
        position ``cost_scatter_idx``; every output position additionally receives the
        single term ``vec[outer_start:] - norms``.

        ``batch_size`` is not used here; it is kept so the signature stays unchanged.
        """
        # The trailing ``norms.shape[0]`` entries of ``vec`` pair one-to-one with ``norms``.
        outer_start = vec.shape[0] - norms.shape[0]

        mat_inner = -cost_mat + vec[cost_vec_idx]

        # Per-output maximum, used as the stabilizing shift of the log-sum-exp.
        max_inner = scatter(mat_inner, cost_scatter_idx, dim_size=norms_outer_batch_idx.shape[0], reduce='max')
        max_outer = vec[outer_start:] - norms
        lse_offset = torch.max(max_inner, max_outer)

        sum_inner = scatter((mat_inner - lse_offset[cost_scatter_idx]).exp(),
                            cost_scatter_idx, dim_size=norms_outer_batch_idx.shape[0], reduce='sum')

        sum_outer = torch.exp(-norms - lse_offset + vec[outer_start:])

        return torch.log(sum_inner + sum_outer) + lse_offset

    @staticmethod
    @torch.jit.script
    def _lse_uv_outer(norms, inner_batch_idx, outer_batch_idx, vec, batch_size: int):
        """Log-sum-exp for the trailing part of ``u`` (or ``v``) in one Sinkhorn half-step.

        Each output element combines its own term ``vec[:outer_start] - norms`` with the
        per-graph log-sum-exp over ``vec[outer_start:]`` (grouped by ``outer_batch_idx``
        and gathered back per element through ``inner_batch_idx``).
        """
        outer_start = inner_batch_idx.shape[0]

        inner_part = -norms + vec[:outer_start]
        logsum_outer = scatter_logsumexp(vec[outer_start:], outer_batch_idx, dim_size=batch_size)
        # One value per inner element: the log-sum-exp of the graph it belongs to.
        logsum_outer_elem = logsum_outer[inner_batch_idx]

        lse_offset = torch.max(inner_part, logsum_outer_elem)

        sum_inner = (inner_part - lse_offset).exp()
        sum_outer = (logsum_outer_elem - lse_offset).exp()

        # ``lse_offset`` is already aligned with ``inner_part`` (one entry per inner
        # element). Gathering it through ``inner_batch_idx`` again would add back the
        # offset of an unrelated element, so it is added directly.
        return torch.log(sum_inner + sum_outer) + lse_offset

    @staticmethod
    def forward(
            ctx,
            cost_mat: torch.FloatTensor,
            cost_batch_idx: torch.LongTensor,
            cost_idx1: torch.LongTensor,
            cost_idx2: torch.LongTensor,
            norms1: torch.FloatTensor,
            norms1_batch_idx: torch.LongTensor,
            norms2: torch.FloatTensor,
            norms2_batch_idx: torch.LongTensor,
            nnodes: torch.LongTensor,
            reg: torch.FloatTensor,
            niter: int = 50):
        """Run ``niter`` log-space Sinkhorn iterations and return the per-graph cost.

        Args:
            cost_mat: Flat vector holding the entries of the sparse inner cost block.
            cost_batch_idx: Graph index of every entry of ``cost_mat``.
            cost_idx1: Position in ``u`` that every entry of ``cost_mat`` refers to.
            cost_idx2: Position in ``v`` that every entry of ``cost_mat`` refers to.
            norms1: Flat costs aligned with ``u[:nnodes.sum(1)[0]]``.
            norms1_batch_idx: Graph index of every entry of ``norms1``.
            norms2: Flat costs aligned with ``v[:nnodes.sum(1)[1]]``.
            norms2_batch_idx: Graph index of every entry of ``norms2``.
            nnodes: Node counts; ``nnodes.shape[1]`` is the batch size and ``nnodes.sum(1)``
                gives the total number of nodes on each of the two sides.
            reg: Per-graph regularization, indexed by the batch index vectors.
            niter: Number of Sinkhorn iterations.

        Returns:
            Tensor with one cost value per graph pair.
        """
        batch_size = nnodes.shape[1]
        nnodes_sums = nnodes.sum(1)

        cost_scaled = cost_mat / reg[cost_batch_idx]
        norms1_scaled = norms1 / reg[norms1_batch_idx]
        norms2_scaled = norms2 / reg[norms2_batch_idx]

        # Layout: u = [side-1 nodes | one extra entry per side-2 node],
        #         v = [side-2 nodes | one extra entry per side-1 node].
        u = cost_mat.new_zeros(nnodes.sum())
        v = cost_mat.new_zeros(nnodes.sum())
        for _ in range(niter):
            u[:nnodes_sums[0]] = -LogSparseSinkhornBPdiag._lse_uv_inner(
                cost_scaled, cost_idx2, cost_idx1,
                norms1_scaled, norms1_batch_idx, v, batch_size)
            u[nnodes_sums[0]:] = -LogSparseSinkhornBPdiag._lse_uv_outer(
                norms2_scaled, norms2_batch_idx, norms1_batch_idx,
                v, batch_size)
            v[:nnodes_sums[1]] = -LogSparseSinkhornBPdiag._lse_uv_inner(
                cost_scaled, cost_idx1, cost_idx2,
                norms2_scaled, norms2_batch_idx, u, batch_size)
            v[nnodes_sums[1]:] = -LogSparseSinkhornBPdiag._lse_uv_outer(
                norms1_scaled, norms1_batch_idx, norms2_batch_idx,
                u, batch_size)

        T11 = (-cost_scaled + u[cost_idx1] + v[cost_idx2]).exp()
        T12 = torch.exp(-norms1_scaled + u[:nnodes_sums[0]] + v[nnodes_sums[1]:])
        T21 = torch.exp(-norms2_scaled + v[:nnodes_sums[1]] + u[nnodes_sums[0]:])

        # These blocks are exactly the gradients returned by backward().
        ctx.save_for_backward(T11, T12, T21, cost_batch_idx, norms1_batch_idx, norms2_batch_idx)

        # The last block is only needed through its row sums (T22_u) and column sums (T22_v).
        T22_u = torch.exp(
                u[nnodes_sums[0]:]
                + scatter_logsumexp(v[nnodes_sums[1]:], norms1_batch_idx, dim_size=batch_size)[norms2_batch_idx])
        T22_v = torch.exp(
                scatter_logsumexp(u[nnodes_sums[0]:], norms2_batch_idx, dim_size=batch_size)[norms1_batch_idx]
                + v[nnodes_sums[1]:])

        C11 = reg * segment_coo((u[cost_idx1] + v[cost_idx2]) * T11, cost_batch_idx, dim_size=batch_size, reduce='sum')
        C12 = reg * (segment_coo(u[:nnodes_sums[0]] * T12, norms1_batch_idx, dim_size=batch_size, reduce='sum')
                     + segment_coo(v[nnodes_sums[1]:] * T12, norms1_batch_idx, dim_size=batch_size, reduce='sum'))
        C21 = reg * (segment_coo(u[nnodes_sums[0]:] * T21, norms2_batch_idx, dim_size=batch_size, reduce='sum')
                     + segment_coo(v[:nnodes_sums[1]] * T21, norms2_batch_idx, dim_size=batch_size, reduce='sum'))
        C22 = reg * (segment_coo(u[nnodes_sums[0]:] * T22_u, norms2_batch_idx, dim_size=batch_size, reduce='sum')
                     + segment_coo(v[nnodes_sums[1]:] * T22_v, norms1_batch_idx, dim_size=batch_size, reduce='sum'))
        C = C11 + C12 + C21 + C22

        return C

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``cost_mat``, ``norms1`` and ``norms2``.

        ``grad_output`` holds one value per graph and is gathered per entry through the
        saved batch index vectors. All other forward inputs receive ``None``.
        """
        T11, T12, T21, cost_batch_idx, norms1_batch_idx, norms2_batch_idx = ctx.saved_tensors

        return (T11 * grad_output[cost_batch_idx], None, None, None,
                T12 * grad_output[norms1_batch_idx], None,
                T21 * grad_output[norms2_batch_idx], None,
                None, None, None)


class LogSparseNystromSinkhornBPdiag(torch.autograd.Function):
    """Entropy-regularized Wasserstein distance with a BP matrix, evaluated in log space.

    The inner cost block is passed as the factor pair ``cost_1a`` / ``sim_a2`` (with signs
    ``sign_a2``) plus a sparse set of entries for which both the exact cost and the
    approximate log-similarity are given; these entries are used as a correction.

    See Peyré, SinkhornAutoDiff (https://github.com/gpeyre/SinkhornAutoDiff/blob/master/sinkhorn_pointcloud.py)
    and Daza, Approximating Wasserstein distances with PyTorch (https://dfdazac.github.io/sinkhorn.html)
    In our case mu = nu = 1.

    reg = 1 / lambda presents a trade-off: smaller values are closer to the true EMD,
    but converge more slowly and are less stable.
    """

    @staticmethod
    @torch.jit.script
    def _lse_uv_inner(cost_1a, sim_a2, sign_a2, sim_corr, corr_sign,
                      corr_batch_idx, corr_idx1, corr_idx2, nodes_idx1, nodes_idx2,
                      norms1_batch_idx, norms1_idx, norms2_batch_idx, norms2_idx,
                      norms, vec, dim: int):
        """Signed log-sum-exp for the first ``max_nodes`` entries of one Sinkhorn half-step.

        Works like the purely factorized version and additionally folds in the sparse
        correction terms ``sim_corr`` (with signs ``corr_sign``). For ``dim % 3 == 2`` the
        corrections are reduced by ``corr_idx1`` and written to the positions
        ``(norms1_batch_idx, norms1_idx)``; otherwise they are reduced by ``corr_idx2`` and
        written to ``(norms2_batch_idx, norms2_idx)``.
        """
        _, max_nodes, _ = cost_1a.shape
        reduce_last = dim % 3 == 2

        if reduce_last:
            lse_a, sign_a = logsumexp_signed_signed(sim_a2 + vec[:, None, :max_nodes],
                                                    sign_a2, dim=2)
            log_terms = (-cost_1a + lse_a[:, None, :])
            term_signs = sign_a[:, None, :]
            corr_log = sim_corr + vec[corr_batch_idx, nodes_idx2]
        else:
            lse_a = torch.logsumexp(vec[:, :max_nodes, None] - cost_1a, dim=1)
            log_terms = (lse_a[:, :, None] + sim_a2)
            term_signs = sign_a2
            corr_log = sim_corr + vec[corr_batch_idx, nodes_idx1]

        # Stabilizing shift: the largest log-term (signs ignored) of the dense parts ...
        peak_inner = log_terms.max(dim).values
        peak_outer = vec[:, max_nodes:] - norms
        shift = torch.max(peak_inner, peak_outer)

        # ... raised wherever a correction term is larger still.
        if reduce_last:
            corr_peak = segment_coo(corr_log, corr_idx1, dim_size=norms1_batch_idx.shape[0], reduce='max')
            shift[norms1_batch_idx, norms1_idx] = torch.max(corr_peak, shift[norms1_batch_idx, norms1_idx])
        else:
            corr_peak = scatter(corr_log, corr_idx2, dim_size=norms2_batch_idx.shape[0], reduce='max')
            shift[norms2_batch_idx, norms2_idx] = torch.max(corr_peak, shift[norms2_batch_idx, norms2_idx])

        inner_sum = (term_signs * (log_terms - shift.unsqueeze(dim)).exp()).sum(dim)
        outer_sum = torch.exp(-norms - shift + vec[:, max_nodes:])

        if reduce_last:
            corr_sum = segment_coo(corr_sign * torch.exp(corr_log - shift[corr_batch_idx, nodes_idx1]),
                                   corr_idx1, dim_size=norms1_batch_idx.shape[0], reduce='sum')
            inner_sum[norms1_batch_idx, norms1_idx] += corr_sum
        else:
            corr_sum = scatter(corr_sign * torch.exp(corr_log - shift[corr_batch_idx, nodes_idx2]),
                               corr_idx2, dim_size=norms2_batch_idx.shape[0], reduce='sum')
            inner_sum[norms2_batch_idx, norms2_idx] += corr_sum

        return (inner_sum + outer_sum).log() + shift

    @staticmethod
    @torch.jit.script
    def _lse_uv_outer(norms: torch.Tensor, vec: torch.Tensor):
        """Log-sum-exp for the last ``max_nodes`` entries of one Sinkhorn half-step.

        Every output entry combines its own term ``vec[:, :max_nodes] - norms`` with the
        row-wise log-sum-exp over ``vec[:, max_nodes:]``.
        """
        _, max_nodes = norms.shape

        tail_lse = torch.logsumexp(vec[:, max_nodes:], dim=-1)
        head = -norms + vec[:, :max_nodes]

        shift = torch.max(head, tail_lse[:, None])

        exp_head = torch.exp(head - shift)
        exp_tail = torch.exp(tail_lse[:, None] - shift)

        return (exp_head + exp_tail).log() + shift

    @staticmethod
    def forward(
            ctx,
            cost_1a: torch.FloatTensor,
            sim_a2: torch.FloatTensor,
            sign_a2: torch.FloatTensor,
            cost_exact: torch.FloatTensor,
            sim_approx: torch.FloatTensor,
            sign_approx: torch.FloatTensor,
            corr_batch_idx: torch.LongTensor,
            corr_idx1: torch.LongTensor,
            corr_idx2: torch.LongTensor,
            nodes_idx1: torch.LongTensor,
            nodes_idx2: torch.LongTensor,
            norms1_batch_idx: torch.LongTensor,
            norms1_idx: torch.LongTensor,
            norms2_batch_idx: torch.LongTensor,
            norms2_idx: torch.LongTensor,
            norms1: torch.FloatTensor,
            norms2: torch.FloatTensor,
            nnodes: torch.LongTensor,
            reg: torch.FloatTensor,
            niter: int = 50):
        """Run ``niter`` log-space Sinkhorn iterations and return the per-sample cost.

        Args:
            cost_1a: First cost factor, 3-D with ``max_nodes`` on axis 1.
            sim_a2: Second (log-similarity) factor, 3-D with ``max_nodes`` on the last axis.
            sign_a2: Signs belonging to ``sim_a2``.
            cost_exact: Exact cost of every corrected (sparse) entry.
            sim_approx: Approximate log-similarity of every corrected entry.
            sign_approx: Signs belonging to ``sim_approx``.
            corr_batch_idx: Batch index of every corrected entry.
            corr_idx1: Reduction index of every corrected entry on the first side.
            corr_idx2: Reduction index of every corrected entry on the second side.
            nodes_idx1: Column of ``u`` addressed by every corrected entry.
            nodes_idx2: Column of ``v`` addressed by every corrected entry.
            norms1_batch_idx: Batch index of the first-side positions receiving corrections.
            norms1_idx: Column index of the first-side positions receiving corrections.
            norms2_batch_idx: Batch index of the second-side positions receiving corrections.
            norms2_idx: Column index of the second-side positions receiving corrections.
            norms1: ``(batch, max_nodes)`` costs paired with the first node set.
            norms2: ``(batch, max_nodes)`` costs paired with the second node set.
            nnodes: Node counts, read as ``nnodes[0]`` and ``nnodes[1]`` (one value per
                sample); positions at or beyond the count are treated as padding.
            reg: Per-sample regularization, one value per batch entry.
            niter: Number of Sinkhorn iterations.

        Returns:
            Tensor with one cost value per sample.
        """
        batch_size, max_nodes = norms1.shape
        n = max_nodes

        cost_1a_scaled = cost_1a / reg[:, None, None]
        norms1_scaled = norms1 / reg[:, None]
        norms2_scaled = norms2 / reg[:, None]

        # Log of the (signed) difference between exact and approximate similarities.
        sim_exact = -cost_exact / reg[corr_batch_idx]
        sim_corr, corr_sign = logdiffexp(sim_exact + 1e-40, sim_approx + 1e-40, sign_approx)

        # True wherever a position is padding (index >= number of nodes of that sample).
        pad1 = (torch.arange(n, dtype=torch.int64,
                             device=norms1.device).expand_as(norms1)
                >= nnodes[0, :, None])
        pad2 = (torch.arange(n, dtype=torch.int64,
                             device=norms2.device).expand_as(norms2)
                >= nnodes[1, :, None])
        pad_u = torch.cat((pad1, pad2), dim=1)
        pad_v = torch.cat((pad2, pad1), dim=1)

        lse_inner = LogSparseNystromSinkhornBPdiag._lse_uv_inner
        lse_outer = LogSparseNystromSinkhornBPdiag._lse_uv_outer
        # Arguments that are identical for every inner call of the loop below.
        kernel_args = (cost_1a_scaled, sim_a2, sign_a2, sim_corr, corr_sign, corr_batch_idx,
                       corr_idx1, corr_idx2, nodes_idx1, nodes_idx2,
                       norms1_batch_idx, norms1_idx, norms2_batch_idx, norms2_idx)

        u = norms1.new_zeros(batch_size, 2 * n)
        v = norms1.new_zeros(batch_size, 2 * n)
        for _ in range(niter):
            u[:, :n] = -lse_inner(*kernel_args, norms1_scaled, v, dim=2)
            u[:, n:] = -lse_outer(norms2_scaled, v)
            u.masked_fill_(pad_u, -1e10)
            v[:, :n] = -lse_inner(*kernel_args, norms2_scaled, u, dim=1)
            v[:, n:] = -lse_outer(norms1_scaled, u)
            v.masked_fill_(pad_v, -1e10)

        # u and v are final from here on, so views on their halves are safe to reuse.
        u_head, u_tail = u[:, :n], u[:, n:]
        v_head, v_tail = v[:, :n], v[:, n:]

        T12 = (-norms1_scaled + u_head + v_tail).exp()
        T21 = (-norms2_scaled + u_tail + v_head).exp()
        T22_u = (u_tail + v_tail.logsumexp(-1)[:, None]).exp()
        T22_v = (u_tail.logsumexp(-1)[:, None] + v_tail).exp()

        # Factorized inner block: keep both factors in log space and only sum out one side.
        T1a_log = -cost_1a_scaled + u_head[:, :, None]
        Ta1_log = sim_a2 + v[:, None, :n]

        Ta1_logsum, Ta1_sum_sign = logsumexp_signed_signed(Ta1_log, sign_a2, dim=2)
        T1a_logsum = torch.logsumexp(T1a_log, dim=1)

        T11_sumright = Ta1_sum_sign[:, None, :] * torch.exp(T1a_log + Ta1_logsum[:, None, :])
        T11_sumleft = sign_a2 * torch.exp(T1a_logsum[:, :, None] + Ta1_log)

        C11 = reg * ((u_head * T11_sumright.sum(2)).sum(1)
                     + (v_head * T11_sumleft.sum(1)).sum(1))
        C12 = reg * ((u_head * T12).sum(1) + (v_tail * T12).sum(1))
        C21 = reg * ((u_tail * T21).sum(1) + (v_head * T21).sum(1))
        C22 = reg * ((u_tail * T22_u).sum(1) + (v_tail * T22_v).sum(1))
        C_nystrom = C11 + C12 + C21 + C22

        # Sparse correction: swap the approximate entries for the exact ones.
        u_corr = u[corr_batch_idx, nodes_idx1]
        v_corr = v[corr_batch_idx, nodes_idx2]
        T11_exact = torch.exp(sim_exact + u_corr + v_corr)
        # NOTE(review): sign_approx is passed to logdiffexp above but is not applied to
        # T11_approx here -- confirm this is intended.
        T11_approx = torch.exp(sim_approx + u_corr + v_corr)
        T11_delta = T11_exact - T11_approx

        delta_rows = segment_coo(T11_delta, corr_idx1, dim_size=norms1_batch_idx.shape[0], reduce='sum')
        delta_cols = scatter(T11_delta, corr_idx2, dim_size=norms2_batch_idx.shape[0], reduce='sum')
        corr_cost_u = segment_coo(u[norms1_batch_idx, norms1_idx] * delta_rows,
                                  norms1_batch_idx, dim_size=batch_size, reduce='sum')
        corr_cost_v = segment_coo(v[norms2_batch_idx, norms2_idx] * delta_cols,
                                  norms2_batch_idx, dim_size=batch_size, reduce='sum')
        C11_corr = reg * (corr_cost_u + corr_cost_v)

        ctx.save_for_backward(T11_sumleft, T11_sumright, T11_exact, T11_approx, corr_batch_idx, T12, T21, reg)

        return C_nystrom + C11_corr

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``cost_1a``, ``sim_a2``, ``cost_exact``, ``sim_approx``,
        ``norms1`` and ``norms2``.

        All other forward inputs receive ``None``.
        """
        T11_sumleft, T11_sumright, T11_exact, T11_approx, corr_batch_idx, T12, T21, reg = ctx.saved_tensors

        grad_3d = grad_output[:, None, None]
        grad_2d = grad_output[:, None]
        grad_corr = grad_output[corr_batch_idx]

        grad_cost_1a = T11_sumright * grad_3d
        grad_sim_a2 = -reg[:, None, None] * T11_sumleft * grad_3d
        grad_cost_exact = T11_exact * grad_corr
        grad_sim_approx = reg[corr_batch_idx] * T11_approx * grad_corr
        grad_norms1 = T12 * grad_2d
        grad_norms2 = T21 * grad_2d

        # Slots 6-15 are sign_approx and the nine index tensors; the last three are
        # nnodes, reg and niter.
        return ((grad_cost_1a, grad_sim_a2, None, grad_cost_exact, grad_sim_approx)
                + (None,) * 10
                + (grad_norms1, grad_norms2, None, None, None))
