import math
import torch

from matching.sinkhorn_padded import argSinkhornPadded
from utils.landmarks import calc_landmarks
from utils import loginvexp, logsumexp_signed_signed, scatter, repeat_blocks, distances, kernels


class DistMatrix:
    """Bundle a distance representation together with the metadata needed to use it.

    ``m`` holds whatever the producer built, e.g. in ``compute_distmatrix`` a dense tensor for
    the padded paths or a list of tensors for the sparse / Nystrom paths. Every argument is
    stored verbatim; nothing is copied or validated.

    Attributes:
        m: Distance matrix tensor, or list of tensors describing it.
        dist_mat_len: Per-sample distance matrix side lengths, as passed by the caller.
        num_nodes: Node counts of both graphs, as passed by the caller.
        sinkhorn_reg: Regularization value (``reg_scaled`` in ``compute_distmatrix``).
        dist_idx: Optional index data locating entries of ``m`` (None if unused).
        norms1: Optional norm data for the first graph (None if unused).
        norms2: Optional norm data for the second graph (None if unused).
    """

    def __init__(self, m, dist_mat_len, num_nodes, sinkhorn_reg,
                 dist_idx=None, norms1=None, norms2=None):
        # Mandatory description of the matrix.
        self.m = m
        self.dist_mat_len = dist_mat_len
        self.num_nodes = num_nodes
        self.sinkhorn_reg = sinkhorn_reg
        # Optional extras; left as None by the dense padded paths.
        self.dist_idx = dist_idx
        self.norms1 = norms1
        self.norms2 = norms2


def _clone_nested(value):
    """Return ``value`` with every tensor cloned, recursing into lists and tuples."""
    if torch.is_tensor(value):
        return value.clone()
    if isinstance(value, list):
        return [_clone_nested(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_clone_nested(item) for item in value)
    # None and other non-tensor values are returned unchanged.
    return value


def clone_distmatrix(distmatrix):
    """Return a copy of ``distmatrix`` whose tensor-valued fields are cloned.

    ``m``, ``sinkhorn_reg``, ``dist_idx``, ``norms1`` and ``norms2`` may be single tensors or,
    for the sparse / Nystrom representations built by ``compute_distmatrix``, lists of tensors.
    Lists are cloned element-wise, so calling ``.clone()`` on them no longer raises
    AttributeError. ``None`` fields stay ``None``. ``dist_mat_len`` and ``num_nodes`` are
    shared with the original object, not copied.
    """
    return DistMatrix(
            _clone_nested(distmatrix.m),
            distmatrix.dist_mat_len,
            distmatrix.num_nodes,
            _clone_nested(distmatrix.sinkhorn_reg),
            _clone_nested(distmatrix.dist_idx),
            _clone_nested(distmatrix.norms1),
            _clone_nested(distmatrix.norms2),
        )


def get_matrix_stack_idx(nnodes):
    """Build (row, col) indices enumerating the entries of a stack of square blocks.

    For each size n in ``nnodes``, one n x n block is enumerated. Row indices are global: the
    node index in the concatenation of all blocks, each repeated n times consecutively. Column
    indices come from ``repeat_blocks(nnodes, nnodes)``; presumably 0..n-1 repeated n times per
    block, which would make the order row-major (verify against ``repeat_blocks``).

    Returns:
        LongTensor of shape (2, sum(n**2)) holding the row and column indices.
    """
    per_row_count = torch.repeat_interleave(nnodes, nnodes)
    all_rows = torch.arange(nnodes.sum(), device=nnodes.device)
    row_idx = torch.repeat_interleave(all_rows, per_row_count)
    col_idx = repeat_blocks(nnodes, nnodes)
    return torch.stack((row_idx, col_idx))


def calc_dist_matrix(node_embeddings, dist_idx, nnodes, dist_mat_len, dist, alpha):
    """Compute a flat list of padded square distance matrices for stacked node embeddings.

    Each sample b contributes a ``dist_mat_len[b] x dist_mat_len[b]`` block, flattened row-major
    (C_ij = i*n + j) and concatenated. Indices beyond a graph's real nodes are redirected to an
    appended all-zero embedding. Such entries therefore become the norm of the other side's
    embedding (the deletion/insertion cost), or 0 when both sides are padding.

    Args:
        node_embeddings: Pair of stacked embeddings, each (total nodes, emb_size).
        dist_idx: (2, sum(dist_mat_len**2)) row/column indices, generated as if every sample had
            ``dist_mat_len`` nodes (e.g. by ``get_matrix_stack_idx``).
        nnodes: (2, batch) node counts of the left and right graphs.
        dist_mat_len: (batch,) side length of each sample's square matrix.
        dist: Object providing ``norm`` (assumed to reduce over the last dimension).
        alpha: Learnable scale of shape (1,) or (emb_size,). It is applied ONLY to entries that
            involve padding; real node-to-node distances are left unscaled.

    Returns:
        Tensor of shape (sum(dist_mat_len**2),).
    """
    # Concatenate 0 embedding for square padding with norms
    emb_zero1 = torch.cat(
            (node_embeddings[0],
             torch.zeros(1, node_embeddings[0].shape[-1], device=node_embeddings[0].device)),
            dim=0)
    emb_zero2 = torch.cat(
            (node_embeddings[1],
             torch.zeros(1, node_embeddings[1].shape[-1], device=node_embeddings[1].device)),
            dim=0)

    csum_mat_len = torch.cumsum(dist_mat_len, dim=0)
    csum_nn = torch.cumsum(nnodes, dim=1)

    # For finding padding entries
    idx_max = torch.repeat_interleave(csum_nn - 1, dist_mat_len**2, dim=1)

    # For fixing mistakes due to only using dist_mat_len for generation
    idx_offset = torch.cat((torch.zeros(2, 1, dtype=torch.long, device=csum_nn.device),
                            (csum_mat_len[None, :] - csum_nn)[:, :-1]),
                           dim=1)
    idx_offset = torch.repeat_interleave(idx_offset, dist_mat_len**2, dim=1)

    # Fix indices by offsetting and redirect overshooting indices to the appended zero row,
    # which turns the entry into a norm (used for deletion/insertion)
    idx_corr = dist_idx - idx_offset
    idx_corr[0].masked_fill_(idx_corr[0] > idx_max[0], emb_zero1.shape[0] - 1)
    idx_corr[1].masked_fill_(idx_corr[1] > idx_max[1], emb_zero2.shape[0] - 1)

    emb_diff = torch.index_select(emb_zero1, dim=0, index=idx_corr[0])
    emb_diff -= torch.index_select(emb_zero2, dim=0, index=idx_corr[1])

    # Entries involving the zero row are not real distances (embedding norms)
    mask_nodist = (idx_corr[0] == emb_zero1.shape[0] - 1) | (idx_corr[1] == emb_zero2.shape[0] - 1)

    # Scale only the non-distances by learnable alpha. The previous code multiplied the whole
    # vector by the mask, which zeroed all real distances.
    # NOTE(review): the dense/Nystrom paths in this file scale norms by alpha**2 / alpha rather
    # than (1 - alpha); the (1 - alpha) factor is kept unchanged here, confirm which is intended.
    if alpha.shape[0] == 1:
        dist_matrix = dist.norm(emb_diff)
        dist_matrix = torch.where(mask_nodist, (1 - alpha)**2 * dist_matrix, dist_matrix)
    else:
        emb_diff = torch.where(mask_nodist[:, None], (1 - alpha[None, :]) * emb_diff, emb_diff)
        dist_matrix = dist.norm(emb_diff)
    # Distance matrix index: C_ij = i*n + j

    return dist_matrix


def calc_dist_matrix_padded(node_embeddings, nnodes, dist, alpha):
    """Square padded distance matrix with embedding norms filling the missing rows/columns.

    Within the max(n1, n2) x max(n1, n2) square of each sample:
    * the real n1 x n2 block holds ``dist.cdist`` distances;
    * columns j >= n2 of a row hold that left node's scaled norm;
    * rows i >= n1 hold the right node's scaled norm.
    Everything else is 0. Norms are scaled by ``alpha**2`` (scalar alpha) or computed from
    ``alpha * embedding`` (per-dimension alpha).

    Args:
        node_embeddings: Pair of (batch, max_nodes, emb) tensors, both padded to max_nodes.
        nnodes: (2, batch) node counts of the left/right graphs.
        dist: Object providing ``cdist`` and ``norm``.
        alpha: Learnable scale of shape (1,) or (emb,).

    Returns:
        (dist_matrix, nll_mask): nll_mask is True exactly on the real n1 x n2 block.
    """
    emb_left, emb_right = node_embeddings[0], node_embeddings[1]
    _, max_nodes, _ = emb_left.shape

    # Real size: b x n1 x n2
    pairwise = dist.cdist(emb_left, emb_right)

    positions = torch.arange(max_nodes, dtype=torch.int64, device=nnodes.device)
    row_pos = positions[:, None].expand_as(pairwise)
    col_pos = positions.expand_as(pairwise)

    row_is_pad = row_pos >= nnodes[0, :, None, None]
    col_is_pad = col_pos >= nnodes[1, :, None, None]
    not_a_distance = row_is_pad | col_is_pad

    # Square region used for matching: max(n1, n2) on both sides
    side = nnodes.max(0).values[:, None, None]
    in_square = (row_pos < side) & (col_pos < side)

    pairwise = pairwise.masked_fill(not_a_distance, 0)

    if alpha.shape[0] != 1:
        norms_left = dist.norm(alpha[None, :] * emb_left)
        norms_right = dist.norm(alpha[None, :] * emb_right)
    else:
        norms_left = alpha**2 * dist.norm(emb_left)
        norms_right = alpha**2 * dist.norm(emb_right)
    pairwise += norms_left[:, :, None] * (in_square & col_is_pad)
    pairwise += norms_right[:, None, :] * (in_square & row_is_pad)

    return pairwise, ~not_a_distance


def calc_dist_matrix_padded_rect(node_embeddings, nnodes, dist):
    """Rectangular padded distance matrix between two batched node sets.

    Args:
        node_embeddings: Pair of (batch, max_nodes, emb) tensors, both padded to max_nodes.
        nnodes: (2, batch) node counts of the left/right graphs.
        dist: Object providing ``cdist``.

    Returns:
        dist_matrix: (batch, max_nodes, max_nodes), zero outside each sample's n1 x n2 block.
        nll_mask: Boolean mask, True exactly on the real n1 x n2 block.
    """
    _, max_nodes, _ = node_embeddings[0].shape
    positions = torch.arange(max_nodes, dtype=torch.int64, device=nnodes.device)

    pairwise = dist.cdist(node_embeddings[0], node_embeddings[1])

    row_valid = positions[:, None].expand_as(pairwise) < nnodes[0, :, None, None]
    col_valid = positions.expand_as(pairwise) < nnodes[1, :, None, None]
    valid = row_valid & col_valid

    return pairwise.masked_fill(~valid, 0), valid


def calc_bp_dist_matrix_padded(node_embeddings, nnodes, dist, alpha, diag=False):
    """Build a padded bipartite (assignment) cost matrix per sample.

    For a sample with n1 = nnodes[0, b] and n2 = nnodes[1, b], the inner (n1+n2) x (n1+n2)
    square is laid out as:
    * [0:n1, 0:n2]: ``dist.cdist`` distances between left and right node embeddings.
    * [0:n1, n2:n1+n2]: norm of left node i. With ``diag=True`` the norm is placed only at
      column n2+i and every other entry of this block is +inf.
    * [n1:n1+n2, 0:n2]: norm of right node j. With ``diag=True`` the norm is placed only at
      row n1+j and every other entry of this block is +inf.
    * [n1:, n2:]: 0.
    Entries outside the inner square (padding up to the batch maximum of n1+n2) are 0. Norms
    are ``alpha**2 * dist.norm(emb)`` for scalar alpha, or ``dist.norm(alpha * emb)`` otherwise.

    Args:
        node_embeddings: Pair of (batch, max_nodes, emb) tensors. Both are assumed to be padded
            to the same max_nodes, with max_nodes <= max(n1+n2) (TODO confirm with callers).
        nnodes: (2, batch) node counts of the left/right graphs.
        dist: Object providing ``cdist`` and ``norm``.
        alpha: Learnable scale of shape (1,) or (emb,).
        diag: Restrict deletion/insertion costs to the diagonals (others +inf).

    Returns:
        (dist_matrix, nll_mask): (batch, max(n1+n2), max(n1+n2)) costs, and a mask that is True
        on the real n1 x n2 distance block.
    """
    batch_size, max_nodes, _ = node_embeddings[0].shape

    # Outer size: b x max(n1+n2) x max(n1+n2); distances go into the top-left corner
    max_n1n2 = nnodes.sum(0).max()
    dist_matrix = node_embeddings[0].new_zeros(batch_size, max_n1n2, max_n1n2)
    dist_matrix[:, :max_nodes, :max_nodes] = dist.cdist(node_embeddings[0], node_embeddings[1])

    # This mask gives everything that is a real calculated distance
    range_nodes = torch.arange(max_n1n2, dtype=torch.int64, device=nnodes.device)
    mask_in_n1 = (range_nodes[:, None].expand_as(dist_matrix) < nnodes[0, :, None, None])
    mask_in_n2 = (range_nodes.expand_as(dist_matrix) < nnodes[1, :, None, None])
    mask_dist = mask_in_n1 & mask_in_n2
    nll_mask = mask_dist

    # Set all non-distances to 0
    dist_matrix = dist_matrix.masked_fill(~mask_dist, 0)

    # This mask gives the (n1+n2) x (n1+n2) matrix needed for node matching (without padding)
    nnodes_sum = nnodes.sum(0)
    mask_inner1 = (range_nodes[:, None].expand_as(dist_matrix) < nnodes_sum[:, None, None])
    mask_inner2 = (range_nodes.expand_as(dist_matrix) < nnodes_sum[:, None, None])
    mask_inner = mask_inner1 & mask_inner2

    # Fill non-distance values in inner matrix with norms of node embeddings (scaled by learnable alpha)
    if alpha.shape[0] == 1:
        fill_val1 = alpha**2 * dist.norm(node_embeddings[0])
        fill_val2 = alpha**2 * dist.norm(node_embeddings[1])
    else:
        fill_val1 = dist.norm(alpha[None, :] * node_embeddings[0])
        fill_val2 = dist.norm(alpha[None, :] * node_embeddings[1])

    if diag:
        # mask_diag1: column j == n2 + row i (deletion diagonal, top-right block)
        # mask_diag2: row i == n1 + column j (insertion diagonal, bottom-left block)
        mask_diag1 = (range_nodes.expand_as(dist_matrix) == nnodes[1][:, None, None] + range_nodes[:, None].expand_as(dist_matrix))
        mask_diag2 = (range_nodes[:, None].expand_as(dist_matrix) == nnodes[0][:, None, None] + range_nodes.expand_as(dist_matrix))
        # Slices to :max_nodes because fill_val1/fill_val2 only have max_nodes entries.
        # The top-right and bottom-left blocks are disjoint, so the +inf written into one
        # is only ever increased by 0 when the other block is updated.
        dist_matrix[:, :max_nodes, :] += fill_val1[:, :, None] * (mask_inner & mask_in_n1 & ~mask_in_n2 & mask_diag1)[:, :max_nodes, :]
        dist_matrix[:, :max_nodes, :].masked_fill_((mask_inner & mask_in_n1 & ~mask_in_n2 & ~mask_diag1)[:, :max_nodes, :], math.inf)
        dist_matrix[:, :, :max_nodes] += fill_val2[:, None, :] * (mask_inner & ~mask_in_n1 & mask_in_n2 & mask_diag2)[:, :, :max_nodes]
        dist_matrix[:, :, :max_nodes].masked_fill_((mask_inner & ~mask_in_n1 & mask_in_n2 & ~mask_diag2)[:, :, :max_nodes], math.inf)
    else:
        # Whole top-right block row-wise gets the left node's norm, whole bottom-left block
        # column-wise gets the right node's norm.
        dist_matrix[:, :max_nodes, :] += fill_val1[:, :, None] * (mask_inner & mask_in_n1 & ~mask_in_n2)[:, :max_nodes, :]
        dist_matrix[:, :, :max_nodes] += fill_val2[:, None, :] * (mask_inner & ~mask_in_n1 & mask_in_n2)[:, :, :max_nodes]

    return dist_matrix, nll_mask


@torch.no_grad()
def get_pair_indices_matched_clusters(nnodes, landmarks, nlandmarks, node_cluster, reg_scaled,
                                      dist, alpha, sinkhorn_niter, threshold, num_hashes):
    """Enumerate the node pairs whose distances are computed exactly in the sparse scheme.

    Nodes are grouped by their cluster id in ``node_cluster``; ids ``>= nlandmarks`` are
    ignored. Left clusters are paired with right clusters as follows:

    * ``threshold < 1``: the left and right landmarks are matched with ``argSinkhornPadded``,
      and every cluster pair (k, l) whose transport-plan entry exceeds ``threshold`` is kept.
      Only supported for ``num_hashes == 1``.
    * otherwise: left cluster k is paired only with right cluster k.

    All (left node, right node) combinations inside each kept cluster pair are returned.

    With ``num_hashes > 1``, the rows of ``node_cluster`` are taken to be ``num_hashes``
    consecutive rows per sample. Pairs from all hash rows of a sample are merged and
    deduplicated.

    Runs under ``torch.no_grad()``, so the returned indices carry no autograd history.

    Args:
        nnodes: (2, B) node counts of the left/right graphs (B = real batch size).
        landmarks: Landmark embeddings. Only used when ``threshold < 1``, where ``landmarks[0]``
            and ``landmarks[1]`` are the left/right landmark sets.
        nlandmarks: Number of clusters per row of ``node_cluster``.
        node_cluster: Pair of (rows, max_nodes) cluster-id tensors for left/right nodes.
            Assumes padded node positions carry ids >= nlandmarks (TODO confirm with
            ``calc_landmarks``).
        reg_scaled: Regularization passed to ``argSinkhornPadded``.
        dist: Distance used between landmarks. A ``kernels.Kernel`` is wrapped in
            ``distances.KernelDist``.
        alpha: Not used in this function.
        sinkhorn_niter: Sinkhorn iterations for landmark matching.
        threshold: Transport-plan threshold for keeping a cluster pair.
        num_hashes: Number of hash rows per sample.

    Returns:
        pairs_batch_idx: Sample index of every pair.
        nodes_left, nodes_right: Per-sample node indices of every pair.
        dists_idx_left, dists_idx_right: Node indices offset by the cumulative node counts of
            the preceding samples.
        Pairs are sorted by ``dists_idx_left`` when ``num_hashes == 1``, and by
        (sample, left node, right node) otherwise.
    """
    batch_size, max_nodes = node_cluster[0].shape

    # Get cluster sizes
    # Cluster ids >= nlandmarks land in the extra bin (index nlandmarks), which is dropped.
    # Results are flattened to (batch_size * nlandmarks,).
    cluster_sizes1 = scatter(node_cluster[0].new_ones(1).expand_as(node_cluster[0]),
                             node_cluster[0], dim=1, dim_size=nlandmarks + 1, reduce='sum')[:, :nlandmarks].flatten()
    cluster_sizes2 = scatter(node_cluster[1].new_ones(1).expand_as(node_cluster[1]),
                             node_cluster[1], dim=1, dim_size=nlandmarks + 1, reduce='sum')[:, :nlandmarks].flatten()

    # Batch offsets for the relative node indices
    offsets_left = torch.cat((torch.zeros(1, dtype=torch.long, device=nnodes.device), nnodes[0, :-1].cumsum(0)))
    offsets_right = torch.cat((torch.zeros(1, dtype=torch.long, device=nnodes.device), nnodes[1, :-1].cumsum(0)))

    # Get nodes in each cluster
    # argsort orders each row's nodes by cluster id. The boolean mask (gathered in the same
    # order) drops nodes with id >= nlandmarks and flattens the result, giving node indices
    # grouped by row, then by cluster.
    batch_idx = torch.arange(batch_size, dtype=node_cluster[0].dtype, device=nnodes.device)[:, None].expand(-1, max_nodes)
    cl_nodes_node_left_all = torch.argsort(node_cluster[0])
    cl_nodes_node_left = cl_nodes_node_left_all[(node_cluster[0] < nlandmarks)[batch_idx, cl_nodes_node_left_all]]
    cl_nodes_node_right_all = torch.argsort(node_cluster[1])
    cl_nodes_node_right = cl_nodes_node_right_all[(node_cluster[1] < nlandmarks)[batch_idx, cl_nodes_node_right_all]]

    if threshold < 1:
        assert num_hashes == 1

        # Coarse scale: Matching between landmarks
        if isinstance(dist, kernels.Kernel):
            dist = distances.KernelDist(dist)
        dist_lm = dist.cdist(landmarks[0], landmarks[1])

        # Match landmarks
        dist_lm_len = dist_lm.new_tensor([nlandmarks])
        T_lm = argSinkhornPadded(dist_lm, dist_lm_len, reg_scaled, sinkhorn_niter)

        # Boolean (batch, nlandmarks, nlandmarks): which cluster pairs are kept
        choice = (T_lm > threshold)

        # Indices for fine scale
        # This can probably be done much simpler via argsort.

        # Cluster pairs
        cl_pairs_batch, cl_pairs_left, cl_pairs_right = torch.where(choice)
        cl_pairs_offset_left = nlandmarks * cl_pairs_batch + cl_pairs_left
        cl_pairs_offset_right = nlandmarks * cl_pairs_batch + cl_pairs_right

        # Number of right nodes in each cluster pair
        cl_pair_sizes2 = cluster_sizes2[cl_pairs_offset_right]

        # Number of copies per cluster
        # (total number of right nodes over all right clusters matched to each left cluster)
        cl_reps_left = scatter(cl_pair_sizes2, cl_pairs_offset_left, dim_size=batch_size * nlandmarks, reduce='sum')
        node_reps_left = torch.repeat_interleave(cl_reps_left, cluster_sizes1)

        # Left indices of node pairs
        nodes_left = torch.repeat_interleave(cl_nodes_node_left, node_reps_left)

        # Relative indices of right nodes in each cluster pair
        # We construct these by misusing repeat_blocks to give indices of present cluster pairs
        # and offset them by the number of nodes in non-present pairs.
        cluster_sizes2_expanded = cluster_sizes2.reshape(batch_size, 1, nlandmarks).expand(batch_size, nlandmarks, nlandmarks).flatten()
        nodes_right_cl_idx_raw = repeat_blocks(cluster_sizes2_expanded, choice.flatten().long())

        # Node offsets for the relative right node indices
        node_offsets = torch.cat((torch.zeros(1, dtype=torch.long, device=nnodes.device),
                                  nnodes[1, :, None].expand(batch_size, nlandmarks).flatten()[:-1].cumsum(0)))

        # Relative indices of right nodes with correct offset
        offsets_right_cl_idx = offsets_right[:, None].expand(batch_size, nlandmarks).flatten() - node_offsets
        offsets_right_cl_idx_repeated = torch.repeat_interleave(offsets_right_cl_idx, cl_reps_left)
        nodes_right_cl_idx = nodes_right_cl_idx_raw + offsets_right_cl_idx_repeated

        # Indices of right nodes in each cluster pair
        nodes_right_cl = cl_nodes_node_right[nodes_right_cl_idx]

        # Right indices of node pairs
        nodes_right_idx = repeat_blocks(cl_reps_left, cluster_sizes1)
        nodes_right = nodes_right_cl[nodes_right_idx]

        # Batch index of each node pair
        pairs_per_sample = (cl_reps_left * cluster_sizes1).reshape(batch_size, nlandmarks).sum(1)
    else:
        # Same-index cluster pairing: each left node of cluster k is repeated once per node of
        # right cluster k
        cluster_sizes2_rep = torch.repeat_interleave(cluster_sizes2, cluster_sizes1)
        nodes_left = torch.repeat_interleave(cl_nodes_node_left, cluster_sizes2_rep)

        nodes_right_idx = repeat_blocks(cluster_sizes2, cluster_sizes1)
        nodes_right = cl_nodes_node_right[nodes_right_idx]

        pairs_per_sample = (cluster_sizes1 * cluster_sizes2).reshape(batch_size, nlandmarks).sum(1)

    if num_hashes == 1:
        # Batch index of each node pair
        pairs_batch_idx = torch.repeat_interleave(torch.arange(batch_size, dtype=torch.long, device=nnodes.device),
                                                  pairs_per_sample)
        # Set up left distance matrix indices
        dists_idx_left = nodes_left + torch.repeat_interleave(offsets_left, pairs_per_sample)

        # Sort indices for efficiency
        # pairs_batch_idx is not permuted. It stays valid because the per-sample offsets keep
        # each sample's indices in a contiguous, increasing range (assuming nodes_left < n1).
        resort_idx = torch.argsort(dists_idx_left)
        dists_idx_left = dists_idx_left[resort_idx]
        nodes_left = nodes_left[resort_idx]
        nodes_right = nodes_right[resort_idx]
    else:
        real_batch_size = batch_size // num_hashes

        # Batch index of each node pair
        # (rows of node_cluster are num_hashes consecutive rows per real sample)
        batch_idx = torch.repeat_interleave(torch.arange(real_batch_size, dtype=torch.long, device=nnodes.device),
                                            num_hashes)
        pairs_batch_idx = torch.repeat_interleave(batch_idx, pairs_per_sample)

        # Deduplicate pairs found by several hash rows; sorted=True orders by (sample, left, right)
        pairs = torch.stack((pairs_batch_idx, nodes_left, nodes_right), dim=1)
        pairs_unique = torch.unique(pairs, sorted=True, dim=0)
        pairs_batch_idx, nodes_left, nodes_right = pairs_unique.unbind(dim=1)

        pairs_per_sample = torch.bincount(pairs_batch_idx, minlength=real_batch_size)

        # Set up left distance matrix indices
        dists_idx_left = nodes_left + torch.repeat_interleave(offsets_left, pairs_per_sample)

    # Set up right distance matrix indices
    dists_idx_right = nodes_right + torch.repeat_interleave(offsets_right, pairs_per_sample)

    return pairs_batch_idx, nodes_left, nodes_right, dists_idx_left, dists_idx_right


def calc_sparse_dist(
        node_embeddings, nnodes, landmarks, nlandmarks,
        node_cluster, reg_scaled, dist, alpha, sinkhorn_niter,
        threshold=0.1, num_hashes=1, calc_norms=True, similarity=False):
    """Exact distances (or similarities) for node pairs in matched clusters, plus norm data.

    The pairs come from ``get_pair_indices_matched_clusters``. For each pair the value is
    ``dist.pairwise_similarity`` if ``similarity`` else ``dist.pairwise_distance``.

    Returns:
        With ``calc_norms=True``:
            ([dists, pairs_batch_idx, dists_idx_left, dists_idx_right],
             [norms1, norms1_batch_idx], [norms2, norms2_batch_idx])
        where norms are scaled by ``alpha**2`` (scalar alpha) or computed from
        ``alpha * embedding``.
        With ``calc_norms=False``:
            ([dists, pairs_batch_idx, dists_idx_left, dists_idx_right],
             [norms1_batch_idx, norms1_idx], [norms2_batch_idx, norms2_idx],
             [nodes_left, nodes_right])
        Note the different element order of the norm lists between the two modes.
    """
    emb_left, emb_right = node_embeddings[0], node_embeddings[1]
    batch_size, _, _ = emb_left.shape

    pairs_batch_idx, nodes_left, nodes_right, dists_idx_left, dists_idx_right = \
        get_pair_indices_matched_clusters(
            nnodes, landmarks, nlandmarks, node_cluster, reg_scaled,
            dist, alpha, sinkhorn_niter, threshold, num_hashes)

    # Fine scale: exact values for the selected pairs only
    pair_fn = dist.pairwise_similarity if similarity else dist.pairwise_distance
    dists = pair_fn(emb_left[pairs_batch_idx, nodes_left],
                    emb_right[pairs_batch_idx, nodes_right])

    # Per-node (sample, node) indices for gathering norms
    sample_ids = torch.arange(batch_size, dtype=torch.long, device=nnodes.device)
    norms1_idx = repeat_blocks(nnodes[0], 1, continuous_indexing=False)
    norms1_batch_idx = torch.repeat_interleave(sample_ids, nnodes[0])
    norms2_idx = repeat_blocks(nnodes[1], 1, continuous_indexing=False)
    norms2_batch_idx = torch.repeat_interleave(sample_ids, nnodes[1])

    sparse_part = [dists, pairs_batch_idx, dists_idx_left, dists_idx_right]

    if not calc_norms:
        return (sparse_part,
                [norms1_batch_idx, norms1_idx], [norms2_batch_idx, norms2_idx],
                [nodes_left, nodes_right])

    picked1 = emb_left[norms1_batch_idx, norms1_idx]
    picked2 = emb_right[norms2_batch_idx, norms2_idx]
    if alpha.shape[0] != 1:
        norms1 = dist.norm(alpha[None, :] * picked1)
        norms2 = dist.norm(alpha[None, :] * picked2)
    else:
        norms1 = alpha**2 * dist.norm(picked1)
        norms2 = alpha**2 * dist.norm(picked2)

    return sparse_part, [norms1, norms1_batch_idx], [norms2, norms2_batch_idx]


def calc_bp_dist_matrix_nystrom(node_embeddings, nnodes, landmarks, reg_scaled, dist, alpha,
                                calc_norms=True, similarity=False):
    """Nystrom factors of the Gibbs kernel between the two node sets.

    With landmarks ``a``, the kernel exp(-D12 / reg) is approximated as
    ``exp(-D1a / reg) @ exp(-Daa / reg)^-1 @ exp(-Da2 / reg)``. The returned factors are:
    * the left factor, as raw distances ``D1a``;
    * the right factor with the inverse landmark kernel already applied, as a signed log
      (``sim_a2``, ``sign_a2``).
    Thus approx log-kernel = log(exp(-D1a / reg) @ (sign_a2 * exp(sim_a2))). This assumes
    ``loginvexp`` returns the signed log of the matrix inverse of ``exp(x)``; confirm in utils.
    With ``similarity=True``, ``dist.csim`` similarities are used instead of negated distances,
    and the left factor is returned as raw similarities.

    Args:
        node_embeddings: Pair of (batch, max_nodes, emb) tensors.
        nnodes: (2, batch) node counts of the left/right graphs.
        landmarks: Landmark embeddings, assumed (batch, nlandmarks, emb).
        reg_scaled: (batch,) regularization.
        dist: Object providing ``cdist``/``csim`` and ``norm``.
        alpha: Learnable norm scale of shape (1,) or (emb,).
        calc_norms: Whether to compute the scaled node norms.
        similarity: Use similarities instead of distances.

    Returns:
        ([left_factor, sim_a2, sign_a2], norms1, norms2). Padded rows/columns are filled with
        +1e20 (distances, norms) or -1e20 (log values, raw similarities). Padded signs are
        set to 1. ``norms1``/``norms2`` are None when ``calc_norms`` is False.
    """
    _, max_nodes, _ = node_embeddings[0].shape

    # Norms of node embeddings (scaled by learnable alpha)
    if calc_norms:
        if alpha.shape[0] == 1:
            norms1 = alpha**2 * dist.norm(node_embeddings[0])
            norms2 = alpha**2 * dist.norm(node_embeddings[1])
        else:
            norms1 = dist.norm(alpha[None, :] * node_embeddings[0])
            norms2 = dist.norm(alpha[None, :] * node_embeddings[1])

        # Fill norms of padded nodes with a very large number (infinity causes backprop errors)
        mask_n1 = (torch.arange(max_nodes, dtype=torch.int64,
                                device=nnodes.device).expand_as(norms1)
                   >= nnodes[0, :, None])
        mask_n2 = (torch.arange(max_nodes, dtype=torch.int64,
                                device=nnodes.device).expand_as(norms2)
                   >= nnodes[1, :, None])
        norms1 = norms1.masked_fill(mask_n1, 1e20)
        norms2 = norms2.masked_fill(mask_n2, 1e20)
    else:
        norms1 = norms2 = None

    if similarity:
        sim_1a = dist.csim(node_embeddings[0], landmarks)
        sim_aa = dist.csim(landmarks, landmarks)
        sim_a2 = dist.csim(landmarks, node_embeddings[1])

        sim_aa = sim_aa / reg_scaled[:, None, None]

        # Signed log of the inverse landmark kernel
        sim_aa_inv, sign_aa = loginvexp(sim_aa, use_double=True)

        # Fold the inverse kernel into the right factor (sum over landmarks, dim=2)
        sim_aa_a2, sign_a2 = logsumexp_signed_signed(
                sim_aa_inv[:, :, :, None] + (sim_a2 / reg_scaled[:, None, None])[:, None, :, :],
                sign_aa[:, :, :, None], dim=2)

        # This mask gives everything that is not a real calculated similarity
        mask_sim_1a = (torch.arange(max_nodes, dtype=torch.int64,
                                    device=nnodes.device)[:, None].expand_as(sim_1a)
                       >= nnodes[0, :, None, None])
        mask_sim_a2 = (torch.arange(max_nodes, dtype=torch.int64,
                                    device=nnodes.device).expand_as(sim_aa_a2)
                       >= nnodes[1, :, None, None])

        # Set all non-similarities to very large negative number (infinity causes backprop errors)
        sim_1a = sim_1a.masked_fill(mask_sim_1a, -1e20)
        sim_aa_a2 = sim_aa_a2.masked_fill(mask_sim_a2, -1e20)
        sign_a2 = sign_a2.masked_fill(mask_sim_a2, 1)

        return [sim_1a, sim_aa_a2, sign_a2], norms1, norms2
    else:
        dist_1a = dist.cdist(node_embeddings[0], landmarks)
        dist_aa = dist.cdist(landmarks, landmarks)
        dist_a2 = dist.cdist(landmarks, node_embeddings[1])

        dist_aa = dist_aa / reg_scaled[:, None, None]

        # Signed log of the inverse landmark kernel
        sim_aa_inv, sign_aa = loginvexp(-dist_aa, use_double=True)

        # Fold the inverse kernel into the right factor (sum over landmarks, dim=2).
        # An equivalent non-log formulation solves exp(-dist_aa) X = exp(-dist_a2 / reg) with a
        # double-precision linear solve; the signed log-domain form is used instead.
        sim_a2, sign_a2 = logsumexp_signed_signed(
                sim_aa_inv[:, :, :, None] - (dist_a2 / reg_scaled[:, None, None])[:, None, :, :],
                sign_aa[:, :, :, None], dim=2)

        # This mask gives everything that is not a real calculated distance
        mask_dist_1a = (torch.arange(max_nodes, dtype=torch.int64,
                                     device=nnodes.device)[:, None].expand_as(dist_1a)
                        >= nnodes[0, :, None, None])
        mask_sim_a2 = (torch.arange(max_nodes, dtype=torch.int64,
                                    device=nnodes.device).expand_as(dist_a2)
                       >= nnodes[1, :, None, None])

        # Padded distances get a very large positive number, padded log values a very large
        # negative number (infinity causes backprop errors)
        dist_1a = dist_1a.masked_fill(mask_dist_1a, 1e20)
        sim_a2 = sim_a2.masked_fill(mask_sim_a2, -1e20)
        sign_a2 = sign_a2.masked_fill(mask_sim_a2, 1)

        return [dist_1a, sim_a2, sign_a2], norms1, norms2


def merge_dist_matrices(dist_mat_decomposed, dist_mat_sparse, dist_idx,
                        norms1_idx, norms2_idx, reg_scaled, similarity=False):
    """Evaluate the Nystrom approximation at the sparse pairs and flatten everything into a list.

    For every sparse pair (sample b, left node i, right node j), the approximate log-kernel
    entry is computed with ``logsumexp_signed_signed`` over the landmark axis:
    log sum_l sign_a2[b, l, j] * exp(sim_a2[b, l, j] +/- sim_dist_1a[b, i, l] / reg[b]).
    The ``+`` is used when ``similarity`` is set; otherwise ``-`` (distances).

    This is quadratic if Nystrom and sparse use the same landmarks:
    l1 * n * n/l2, l1 constant, l2 ~ n^(2/3) -> l1 * n^(4/3).

    Returns:
        [sim_dist_1a, sim_a2, sign_a2, sim_dists_exact, sim_approx, sign_approx,
         pairs_batch_idx, dists_idx_left, dists_idx_right, nodes_left, nodes_right,
         norms1_batch_idx, norms1_idx, norms2_batch_idx, norms2_idx]
    """
    left_factor, right_log, right_sign = dist_mat_decomposed
    exact_vals, pair_batch, idx_left, idx_right = dist_mat_sparse
    left_nodes, right_nodes = dist_idx
    n1_batch, n1_nodes = norms1_idx
    n2_batch, n2_nodes = norms2_idx

    left_term = (left_factor / reg_scaled[:, None, None])[pair_batch, left_nodes]
    if not similarity:
        # Distances enter the exponent negated
        left_term = -left_term

    approx_log, approx_sign = logsumexp_signed_signed(
            right_log[pair_batch, :, right_nodes] + left_term,
            right_sign[pair_batch, :, right_nodes], dim=-1)

    return [left_factor, right_log, right_sign, exact_vals, approx_log, approx_sign,
            pair_batch, idx_left, idx_right, left_nodes, right_nodes,
            n1_batch, n1_nodes, n2_batch, n2_nodes]


def compute_distmatrix(node_embeddings, num_nodes, dist_mat_len, sparse, nystrom,
                       reg_scaled, sinkhorn_niter, dist, alpha,
                       bp_dist_matrix=True, sparse_batching=False):
    """Compute the node distance representation used for graph matching.

    Dispatch:
    * ``sparse_batching``: flat padded distances over stacked (2-D) node embeddings
      (``calc_dist_matrix``).
    * ``sparse`` and ``nystrom``: Nystrom factors merged with exact sparse pairs.
    * ``sparse`` only: exact distances for pairs in matched clusters (``calc_sparse_dist``).
    * ``nystrom`` only: Nystrom factors (``calc_bp_dist_matrix_nystrom``).
    * otherwise: a dense padded matrix, bipartite if ``bp_dist_matrix`` else rectangular.

    ``sparse`` and ``nystrom`` are config dicts, or None/falsy to disable them. Both are read
    with the keys 'method', 'nlandmarks' and 'centroids'; ``sparse`` additionally uses
    'threshold' and 'nhashes'. Their 'centroids' entries are overwritten with the values
    returned by ``calc_landmarks``.

    Returns:
        (DistMatrix, nll_mask). ``nll_mask`` is only set by the dense padded paths, else None.
    """
    dist_idx = None
    norms1 = None
    norms2 = None
    nll_mask = None
    # Separate per-side landmarks are needed when landmark clusters are matched (threshold < 1).
    # Same truthiness test as the branches below, so any falsy config disables sparsity.
    separate = bool(sparse) and sparse['threshold'] < 1

    if sparse_batching:
        # Here 2 node_embeddings tensors with sum(num_nodes) x emb_size (2-D), so the batched
        # 3-D shape unpacking of the other branch must not be applied here.
        # Generate indices as if we always had dist_mat_len nodes.
        dist_idx = get_matrix_stack_idx(dist_mat_len)
        dist_matrix = calc_dist_matrix(node_embeddings, dist_idx,
                                       num_nodes, dist_mat_len, dist, alpha)
    else:
        # Here 2 node_embeddings tensors with batch_size x max_num_nodes x emb_size
        _, max_nodes, emb_size = node_embeddings[0].shape
        if sparse and nystrom:
            landmarks_sparse, node_cluster, sparse['centroids'] = calc_landmarks(
                    sparse['method'], node_embeddings, num_nodes,
                    sparse['nlandmarks'], reg_scaled, dist=dist,
                    nhashes=sparse['nhashes'], return_clusters=True,
                    centroids=sparse['centroids'], separate=separate)
            if separate:
                landmarks_sparse = landmarks_sparse.reshape(2, -1, sparse['nlandmarks'], emb_size)
            node_cluster = [cl.reshape(-1, max_nodes) for cl in node_cluster]
            dist_matrix_ms, norms1_idx, norms2_idx, dist_idx = calc_sparse_dist(
                    node_embeddings, num_nodes, landmarks_sparse, sparse['nlandmarks'], node_cluster,
                    reg_scaled=reg_scaled, dist=dist, alpha=alpha, sinkhorn_niter=sinkhorn_niter,
                    threshold=sparse['threshold'], num_hashes=sparse['nhashes'], calc_norms=False)
            # Sparse landmarks can be reused for Nystrom only if they are a single shared set.
            # Separate per-side landmarks, shaped (2, batch, nlandmarks, emb) above, would
            # double the batch dimension after the reshape below.
            if (not separate
                    and nystrom['method'] == sparse['method']
                    and nystrom['nlandmarks'] == sparse['nlandmarks']):
                landmarks_nystrom = landmarks_sparse
            else:
                landmarks_nystrom, nystrom['centroids'] = calc_landmarks(
                        nystrom['method'], node_embeddings, num_nodes,
                        nystrom['nlandmarks'], reg_scaled, dist=dist,
                        return_clusters=False, centroids=nystrom['centroids'])
            landmarks_nystrom = landmarks_nystrom.reshape(-1, nystrom['nlandmarks'], emb_size)
            dist_matrix_nys, norms1, norms2 = calc_bp_dist_matrix_nystrom(
                    node_embeddings, num_nodes, landmarks_nystrom,
                    reg_scaled=reg_scaled, dist=dist, alpha=alpha)
            dist_matrix = merge_dist_matrices(dist_matrix_nys, dist_matrix_ms,
                                              dist_idx, norms1_idx, norms2_idx, reg_scaled)
        elif sparse:
            landmarks, node_cluster, sparse['centroids'] = calc_landmarks(
                    sparse['method'], node_embeddings, num_nodes,
                    sparse['nlandmarks'], reg_scaled, dist=dist,
                    nhashes=sparse['nhashes'], return_clusters=True,
                    centroids=sparse['centroids'], separate=separate)
            if separate:
                landmarks = landmarks.reshape(2, -1, sparse['nlandmarks'], emb_size)
            node_cluster = [cl.reshape(-1, max_nodes) for cl in node_cluster]
            dist_matrix, norms1, norms2 = calc_sparse_dist(
                    node_embeddings, num_nodes, landmarks, sparse['nlandmarks'], node_cluster,
                    reg_scaled=reg_scaled, dist=dist, alpha=alpha, sinkhorn_niter=sinkhorn_niter,
                    threshold=sparse['threshold'], num_hashes=sparse['nhashes'])
        elif nystrom:
            landmarks, nystrom['centroids'] = calc_landmarks(
                    nystrom['method'], node_embeddings, num_nodes,
                    nystrom['nlandmarks'], reg_scaled, dist=dist,
                    return_clusters=False, centroids=nystrom['centroids'])
            landmarks = landmarks.reshape(-1, nystrom['nlandmarks'], emb_size)
            dist_matrix, norms1, norms2 = calc_bp_dist_matrix_nystrom(
                    node_embeddings, num_nodes, landmarks,
                    reg_scaled=reg_scaled, dist=dist, alpha=alpha)
        elif bp_dist_matrix:
            dist_matrix, nll_mask = calc_bp_dist_matrix_padded(node_embeddings, num_nodes, dist, alpha, diag=True)
        else:
            dist_matrix, nll_mask = calc_dist_matrix_padded_rect(node_embeddings, num_nodes, dist)
    return DistMatrix(dist_matrix, dist_mat_len, num_nodes, reg_scaled, dist_idx, norms1, norms2), nll_mask
