from functools import partial
import torch
import torch.nn as nn
import numpy as np
from utils import scatter, kernels, distances, repeat_blocks
from utils.lsh import get_angular_buckets


def get_matching_embeddings(X, dist):
    """Find pairs of (near-)identical embeddings inside each batch element.

    Args:
        X: Batched embeddings; `dist.cdist(X, X)` must yield a 3-D tensor of
            pairwise distances (batch, n, n).
        dist: Object exposing a `cdist(A, B)` method.

    Returns:
        Tuple `(batch, first, second)` of index tensors, one entry per pair
        whose distance is within 1e-5 of zero, keeping only `first < second`
        so self-matches and mirrored duplicates are excluded.
    """
    pairwise = dist.cdist(X, X)
    is_zero = torch.isclose(pairwise, X.new_zeros(1), atol=1e-5)
    batch_ids, first, second = torch.where(is_zero)
    # Keep only the strict upper triangle of every distance matrix
    upper = first < second
    return batch_ids[upper], first[upper], second[upper]


def get_kmeanspp_sample(X, C, mask, dist):
    """Draw one k-means++ style sample per batch element.

    Each node is weighted by the squared distance to its nearest centroid in
    `C`; entries flagged in `mask` get weight zero. Rows whose weights are all
    zero fall back to uniform weights over every entry (masked ones included).

    Args:
        X: Embeddings, indexed as (batch, node, feature).
        C: Current centroids, (batch, cluster, feature).
        mask: Boolean tensor (batch, node); True entries are excluded.
        dist: Object exposing a `cdist(A, B)` method.

    Returns:
        One sampled embedding per batch element, taken from `X`.
    """
    nearest = dist.cdist(X, C).min(dim=2).values  # (batch, node) distance to closest centroid
    nearest.masked_fill_(mask, 0)
    weights = nearest.pow(2)
    # Degenerate rows (all weights zero) cannot be sampled from; make them uniform
    all_zero = weights.sum(-1) == 0
    weights[all_zero, :] = 1
    chosen = torch.multinomial(weights, 1).squeeze()
    every_batch = torch.arange(X.shape[0], dtype=torch.long, device=X.device)
    return X[every_batch, chosen]


def add_kmeanspp_sample(X, C, dist_XC, isample, mask, dist):
    """Append the `isample`-th k-means++ centroid to `C` in place.

    The distances from every node to the previously added centroid
    (`isample - 1`) are cached in `dist_XC`, then a new centroid is sampled
    with probability proportional to the squared distance to the nearest of
    the first `isample` centroids. Masked nodes get weight zero; rows whose
    weights are all zero are sampled uniformly.

    Args:
        X: Embeddings, indexed as (batch, node, feature).
        C: Centroid buffer (batch, ksamples, feature); slot `isample` is written.
        dist_XC: Distance cache (batch, node, ...); column `isample - 1` is written.
        isample: Index of the centroid to create (>= 1).
        mask: Boolean tensor (batch, node); True entries are excluded.
        dist: Object exposing a `cdist(A, B)` method.

    Returns:
        None. `C` and `dist_XC` are modified in place.
    """
    previous = isample - 1
    last_centroid = C[:, previous][:, None]  # (batch, 1, feature)
    dist_XC[:, :, previous] = dist.cdist(X, last_centroid).squeeze(dim=-1)

    nearest = dist_XC[:, :, :isample].min(dim=2).values  # (batch, node) distance to closest centroid
    nearest.masked_fill_(mask, 0)
    weights = nearest.pow(2)
    all_zero = weights.sum(-1) == 0
    weights[all_zero, :] = 1

    chosen = torch.multinomial(weights, 1, replacement=True).squeeze()
    every_batch = torch.arange(X.shape[0], dtype=torch.long, device=X.device)
    C[:, isample] = X[every_batch, chosen]


@torch.no_grad()
def get_cluster_assignments(X: torch.Tensor, C: torch.Tensor,
                            mask_value: int, dist: "distances.Distance",
                            nnodes: torch.Tensor = None, mask: torch.Tensor = None):
    """Assign every node to its nearest centroid.

    Args:
        X: Embeddings, indexed as (batch, node, feature).
        C: Centroids, (batch, cluster, feature).
        mask_value: Label written for masked (padding) nodes.
        dist: Object exposing a `cdist(A, B)` method returning
            (batch, node, cluster) distances.
        nnodes: Number of valid nodes per batch element, shape (batch,).
            Only used to build the padding mask when `mask` is None.
        mask: Boolean tensor (batch, node); True marks padding entries.

    Returns:
        Long tensor (batch, node) holding the index of the nearest centroid,
        or `mask_value` for masked nodes.

    Raises:
        ValueError: If neither `nnodes` nor `mask` is provided.
    """
    if mask is None:
        if nnodes is None:
            raise ValueError("Either `nnodes` or `mask` must be provided")
        batch_size, max_nodes, _ = X.shape
        mask = (torch.arange(max_nodes, dtype=torch.long, device=X.device).expand(batch_size, -1)
                >= nnodes[:, None])

    D_ij = dist.cdist(X, C)  # (batch_size, nodes, kclusters)
    # argmin already removes the cluster axis. An extra squeeze here would also
    # drop the node axis whenever there is exactly one node, which breaks the
    # masked fill below.
    node_cluster = D_ij.argmin(dim=2)  # (batch_size, nodes) Nearest cluster
    node_cluster.masked_fill_(mask, mask_value)
    return node_cluster


@torch.no_grad()
def kmeanspp_sampling_cat(X, ksamples, dist, nnodes=None, mask_padding=None):
    """Sample `ksamples` landmarks per batch element with k-means++ seeding.

    Args:
        X: Embeddings, indexed as (batch, node, feature). When the padding
            mask is derived from `nnodes`, the node axis is treated as two
            concatenated halves of equal length.
        ksamples: Number of landmarks to draw per batch element.
        dist: Object exposing a `cdist(A, B)` method.
        nnodes: Valid-node counts indexed as `nnodes[half, batch]`. Required
            when `mask_padding` is None.
        mask_padding: Optional boolean tensor (batch, node); True marks
            padding entries that must not be sampled.

    Returns:
        Tensor (batch, ksamples, feature) of sampled landmarks, with the same
        dtype and device as `X`.
    """
    batch_size, max_nodes2, emb_size = X.shape
    max_nodes = max_nodes2 // 2
    if nnodes is not None:
        assert torch.all(nnodes.sum(0) >= ksamples)

    if mask_padding is None:
        assert nnodes is not None
        node_pos = torch.arange(max_nodes, dtype=torch.long, device=X.device).expand(batch_size, -1)
        mask_padding = torch.zeros((batch_size, 2 * max_nodes), dtype=torch.bool, device=X.device)
        mask_padding[:, :max_nodes] = node_pos >= nnodes[0, :, None]
        mask_padding[:, max_nodes:] = node_pos >= nnodes[1, :, None]

    # Allocate from X so dtype and device follow the input. A default-dtype
    # buffer breaks the indexed assignment below for non-float32 inputs.
    C = X.new_empty((batch_size, ksamples, emb_size))

    # First landmark: uniform over the valid nodes of each batch element
    batch_idx = torch.arange(batch_size, dtype=torch.long, device=X.device)[:, None]
    mask_nodes = ~mask_padding
    # Rows without any valid node fall back to sampling from all entries
    mask_nodes[mask_nodes.sum(-1) == 0, :] = True
    C[batch_idx, 0] = X[batch_idx, torch.multinomial(mask_nodes.float(), 1, replacement=True)]

    # Remaining landmarks: distance-weighted sampling with a running distance cache
    dist_XC = X.new_empty((batch_size, max_nodes2, ksamples - 1))
    for i in range(1, ksamples):
        add_kmeanspp_sample(X, C, dist_XC, i, mask_padding, dist)
    return C


@torch.no_grad()
def kmeans_cat_padded(X, nnodes, kclusters, niter, dist, kmeanspp_init=False):
    """Run masked k-means on two padded node sets concatenated along the node axis.

    Args:
        X: Embeddings (batch, 2 * max_nodes, feature); the first and second
            half of the node axis are padded separately.
        nnodes: 2-D tensor of valid-node counts indexed as `nnodes[half, batch]`.
        kclusters: Number of clusters.
        niter: Number of k-means iterations.
        dist: Object exposing a `cdist(A, B)` method.
        kmeanspp_init: Forwarded to `kmeans_masked`.

    Returns:
        Whatever `kmeans_masked` returns (centroids and node assignments).
    """
    num_graphs, total_nodes, _ = X.shape
    half = total_nodes // 2
    assert half >= kclusters
    assert nnodes.dim() == 2

    positions = torch.arange(half, dtype=torch.long, device=X.device).expand(num_graphs, -1)
    is_padding = torch.zeros((num_graphs, 2 * half), dtype=torch.bool, device=X.device)
    # Fill the padding mask one half at a time, each with its own node counts
    for part in range(2):
        is_padding[:, part * half:(part + 1) * half] = positions >= nnodes[part, :, None]

    return kmeans_masked(X, is_padding, kclusters, niter, dist, kmeanspp_init=kmeanspp_init)


@torch.no_grad()
def kmeans_padded(X, nnodes, kclusters, niter, dist, kmeanspp_init=False):
    """Run masked k-means on a single padded node set.

    Args:
        X: Embeddings (batch, max_nodes, feature).
        nnodes: 1-D tensor with the number of valid nodes per batch element.
        kclusters: Number of clusters.
        niter: Number of k-means iterations.
        dist: Object exposing a `cdist(A, B)` method.
        kmeanspp_init: Forwarded to `kmeans_masked`.

    Returns:
        Whatever `kmeans_masked` returns (centroids and node assignments).
    """
    num_graphs, width, _ = X.shape
    assert width >= kclusters
    assert nnodes.dim() == 1

    # Positions at or beyond the node count of a batch element are padding
    positions = torch.arange(width, dtype=torch.long, device=X.device).expand(num_graphs, -1)
    is_padding = positions >= nnodes[:, None]

    return kmeans_masked(X, is_padding, kclusters, niter, dist, kmeanspp_init=kmeanspp_init)


@torch.no_grad()
def kmeans_masked(X, mask, kclusters, niter, dist, kmeanspp_init=False, empty_threshold=None):
    """Batched k-means over padded inputs.

    Args:
        X: Embeddings, indexed as (batch, node, feature).
        mask: Boolean tensor (batch, node); True marks padding entries.
        kclusters: Number of clusters per batch element.
        niter: Number of k-means iterations. With `niter == 0` the initial
            centroids are returned together with the assignments to them.
        dist: Object exposing a `cdist(A, B)` method.
        kmeanspp_init: Use k-means++ seeding instead of random initialization.
        empty_threshold: Empty clusters are only re-seeded for batch elements
            whose total node count exceeds this value. Defaults to `kclusters`.

    Returns:
        Tuple `(C, node_cluster)`: centroids (batch, kclusters, feature) and
        the assignments (batch, node) computed in the last iteration, i.e.
        relative to the centroids before the final update. Masked nodes carry
        the label `kclusters`.
    """
    batch_size = X.shape[0]
    if empty_threshold is None:
        empty_threshold = kclusters

    def assign_and_count(centroids):
        # Label every node and count members per cluster. Padding is routed to
        # the extra slot `kclusters`, which is sliced away from the counts.
        # NOTE(review): relies on `utils.scatter` supporting reduce='sum' with
        # an explicit `dim_size`, as in torch_scatter - confirm.
        labels = get_cluster_assignments(X, centroids, mask_value=kclusters, dist=dist, mask=mask)
        counts = scatter(X.new_ones(1).expand_as(labels), labels,
                         dim=1, dim_size=kclusters + 1, reduce='sum')[:, :kclusters]  # (batch_size, kclusters)
        return labels, counts

    if kmeanspp_init:
        C = kmeanspp_sampling_cat(X, kclusters, dist, mask_padding=mask)
    else:
        # Simple random initialization
        node_idx = torch.multinomial((~mask).float(), kclusters, replacement=False)
        batch_idx = torch.arange(batch_size, dtype=torch.long, device=X.device)[:, None].expand_as(node_idx)
        C = X[batch_idx, node_idx]

        # Resample to prevent matching centroids
        batch_matching, idx_matching1, _ = get_matching_embeddings(C, dist=dist)
        while len(batch_matching) > 0:
            C[batch_matching, idx_matching1] = get_kmeanspp_sample(X[batch_matching], C[batch_matching],
                                                                   mask[batch_matching], dist=dist)
            batch_matching, idx_matching1, _ = get_matching_embeddings(C, dist=dist)

    node_cluster = None
    for _ in range(niter):
        node_cluster, Ncl = assign_and_count(C)
        empty_clusters = (Ncl == 0.) & (Ncl.sum(-1) > empty_threshold)[:, None]

        # Reassign centroids of empty clusters
        if torch.any(empty_clusters):
            batch, C_idx = torch.where(empty_clusters)
            C[batch, C_idx] = get_kmeanspp_sample(X[batch], C[batch], mask[batch], dist=dist)
            node_cluster, Ncl = assign_and_count(C)

        # New centroid = mean of the assigned nodes; the epsilon avoids 0/0 for
        # clusters that stay empty.
        C = scatter(X, node_cluster[:, :, None],
                    dim=1, dim_size=kclusters + 1, reduce='sum')[:, :kclusters, :] / (Ncl[:, :, None] + 1e-10)

    if node_cluster is None:
        # niter == 0: the loop never ran, so report assignments to the initial
        # centroids instead of failing on an unbound variable.
        node_cluster = get_cluster_assignments(X, C, mask_value=kclusters, dist=dist, mask=mask)

    return C, node_cluster


def get_other_sample(X, C, p, dist, min_dist=1e-5):
    """Sample one node per batch element that does not coincide with any entry of `C`.

    Args:
        X: Embeddings, indexed as (batch, node, feature).
        C: Already chosen samples, (batch, sample, feature).
        p: Non-negative sampling weights (batch, node). Not modified.
        dist: Object exposing a `cdist(A, B)` method.
        min_dist: Nodes closer than this to any entry of `C` are excluded.

    Returns:
        One embedding per batch element, taken from `X`.
    """
    D_ij = dist.cdist(X, C)  # (batch_size, nodes, samples)
    too_close = torch.any(D_ij < min_dist, dim=-1)
    # Out-of-place fill: the caller's weights must stay intact for later draws
    probs = p.masked_fill(too_close, 0)
    rows = torch.multinomial(probs, 1).squeeze(-1)
    batch_idx = torch.arange(X.shape[0], dtype=torch.long, device=X.device)
    return X[batch_idx, rows]


@torch.no_grad()
def proportional_sampling(X, p, ksamples, dist):
    """Draw `ksamples` distinct embeddings per batch element with probability proportional to `p`.

    Args:
        X: Embeddings, indexed as (batch, node, feature).
        p: Non-negative sampling weights (batch, node).
        ksamples: Number of samples per batch element.
        dist: Object exposing a `cdist(A, B)` method.

    Returns:
        Tensor (batch, ksamples, feature) of sampled embeddings. Samples that
        coincide with another sample of the same batch element are redrawn
        until no matching pair remains.
    """
    num_graphs, _, _ = X.shape

    picked = torch.multinomial(p, ksamples, replacement=False)
    graph_ids = torch.arange(num_graphs, dtype=torch.long, device=X.device)[:, None].expand_as(picked)
    samples = X[graph_ids, picked]

    # Distinct indices can still carry identical embeddings; redraw those
    while True:
        dup_batch, dup_slot, _ = get_matching_embeddings(samples, dist=dist)
        if len(dup_batch) == 0:
            break
        samples[dup_batch, dup_slot] = get_other_sample(X[dup_batch], samples[dup_batch],
                                                        p[dup_batch], dist=dist)
    return samples


@torch.no_grad()
def uniform_sampling_cat(X, nnodes, ksamples, dist):
    """Uniformly sample `ksamples` landmarks from the valid nodes of two concatenated node sets.

    Args:
        X: Embeddings (batch, 2 * max_nodes, feature); each half of the node
            axis is padded separately.
        nnodes: Valid-node counts indexed as `nnodes[half, batch]`.
        ksamples: Number of landmarks per batch element.
        dist: Object exposing a `cdist(A, B)` method.

    Returns:
        The result of `proportional_sampling` with equal weight on every
        valid node and zero weight on padding.
    """
    num_graphs, total_nodes, _ = X.shape
    half = total_nodes // 2
    assert torch.all(nnodes.sum(0) >= ksamples)

    positions = torch.arange(half, dtype=torch.long, device=X.device).expand(num_graphs, -1)
    is_valid = torch.zeros((num_graphs, 2 * half), dtype=torch.bool, device=X.device)
    # Mark the valid nodes of each half using that half's node counts
    for part in range(2):
        is_valid[:, part * half:(part + 1) * half] = positions < nnodes[part, :, None]

    return proportional_sampling(X, is_valid.float(), ksamples, dist)


@torch.no_grad()
def norm_sampling_cat(X, nnodes, ksamples, dist, norm_fn=partial(torch.norm, p=2, dim=-1)):
    """Sample `ksamples` landmarks with probability proportional to the embedding norm.

    Args:
        X: Embeddings (batch, 2 * max_nodes, feature); each half of the node
            axis is padded separately.
        nnodes: Valid-node counts indexed as `nnodes[half, batch]`.
        ksamples: Number of landmarks per batch element.
        dist: Object exposing a `cdist(A, B)` method.
        norm_fn: Callable mapping `X` to per-node weights (batch, node);
            defaults to the L2 norm over the feature axis.

    Returns:
        The result of `proportional_sampling` with the norms as weights and
        zero weight on padding.
    """
    num_graphs, total_nodes, _ = X.shape
    half = total_nodes // 2
    assert torch.all(nnodes.sum(0) >= ksamples)

    weights = norm_fn(X)

    positions = torch.arange(half, dtype=torch.long, device=X.device).expand(num_graphs, -1)
    is_valid = torch.zeros((num_graphs, 2 * half), dtype=torch.bool, device=X.device)
    for part in range(2):
        is_valid[:, part * half:(part + 1) * half] = positions < nnodes[part, :, None]

    # Zero the weight of padding entries so they are never drawn
    weights.mul_(is_valid)

    return proportional_sampling(X, weights, ksamples, dist)


@torch.no_grad()
def norm_choice_cat(X, nnodes, ksamples, norm_fn=partial(torch.norm, p=2, dim=-1)):
    """Pick the `ksamples` valid nodes with the largest norm as landmarks.

    Args:
        X: Embeddings (batch, 2 * max_nodes, feature); each half of the node
            axis is padded separately.
        nnodes: Valid-node counts indexed as `nnodes[half, batch]`.
        ksamples: Number of landmarks per batch element.
        norm_fn: Callable mapping `X` to per-node weights (batch, node);
            defaults to the L2 norm over the feature axis.

    Returns:
        Tensor (batch, ksamples, feature) with the selected embeddings, in no
        particular order.
    """
    batch_size, max_nodes2, emb_size = X.shape
    max_nodes = max_nodes2 // 2
    assert torch.all(nnodes.sum(0) >= ksamples)

    weights = norm_fn(X)

    node_pos = torch.arange(max_nodes, dtype=torch.long, device=X.device).expand(batch_size, -1)
    mask_nodes = torch.zeros((batch_size, 2 * max_nodes), dtype=torch.bool, device=X.device)
    mask_nodes[:, :max_nodes] = node_pos < nnodes[0, :, None]
    mask_nodes[:, max_nodes:] = node_pos < nnodes[1, :, None]

    # Padding gets -inf rather than 0, so it can never tie with (and displace)
    # a valid node whose weight is zero. Out-of-place to leave the `norm_fn`
    # output untouched.
    weights = weights.masked_fill(~mask_nodes, float('-inf'))

    landmark_idx = torch.topk(weights, k=ksamples, dim=-1, sorted=False).indices
    batch_idx = torch.arange(batch_size, dtype=torch.long, device=X.device)[:, None].expand_as(landmark_idx)
    samples = X[batch_idx, landmark_idx]

    return samples


def kmeans_hier_inner(node_embeddings_cat, node_cluster_outer, inner_clusters, outer_clusters, niter, dist):
    """Refine an existing clustering by running k-means inside every outer cluster.

    Args:
        node_embeddings_cat: Embeddings, indexed as (batch, node, feature).
        node_cluster_outer: Outer cluster label per node (batch, node). Labels
            `>= outer_clusters` are treated as padding and skipped.
        inner_clusters: Number of sub-clusters per outer cluster.
        outer_clusters: Number of outer clusters per batch element.
        niter: Number of k-means iterations for the inner clustering.
        dist: Object exposing a `cdist(A, B)` method.

    Returns:
        Tuple `(landmarks, node_cluster_cat)`: the centroids returned by
        `kmeans_masked` for the regrouped input (leading dimension
        `batch * outer_clusters`), and the combined label per node
        `outer * inner_clusters + inner`, shape (batch, node).
    """
    batch_size, _, emb_size = node_embeddings_cat.shape

    # Sort nodes by outer label so the members of each cluster are contiguous
    cluster_idx, node_idx = torch.sort(node_cluster_outer, dim=-1)
    batch_idx = torch.arange(batch_size, dtype=torch.long, device=node_idx.device)[:, None].expand_as(node_idx)
    real_node = (cluster_idx < outer_clusters)
    # Flatten (batch, outer cluster) into one global cluster id and drop padding
    cluster_idx = (batch_idx * outer_clusters + cluster_idx)[real_node]
    batch_idx = batch_idx[real_node]
    node_idx = node_idx[real_node]

    clusters_total = batch_size * outer_clusters
    cluster_sizes = scatter(cluster_idx.new_ones(1).expand_as(cluster_idx), cluster_idx,
                            dim_size=clusters_total, dim=-1, reduce='sum')
    # Plain int: used several times as a tensor dimension below
    max_cluster = int(cluster_sizes.max())
    mask_padding = (torch.arange(max_cluster, dtype=torch.long, device=node_idx.device).expand(clusters_total, -1)
                    >= cluster_sizes[:, None])

    # Regroup the embeddings as (global cluster, position in cluster, feature)
    embeddings_inner = node_embeddings_cat.new_zeros((clusters_total, max_cluster, emb_size))
    # NOTE(review): presumably yields 0..size-1 for every cluster, restarting
    # at each block - confirm against `utils.repeat_blocks`.
    node_idx_inner = repeat_blocks(cluster_sizes, 1, continuous_indexing=False)
    embeddings_inner[cluster_idx, node_idx_inner] = node_embeddings_cat[batch_idx, node_idx]

    landmarks, node_cluster_cat_inner = kmeans_masked(
            embeddings_inner, mask_padding, kclusters=inner_clusters, niter=niter, dist=dist,
            kmeanspp_init=True, empty_threshold=2 * inner_clusters)

    # Scatter the inner labels back to the original node order. Padding nodes
    # keep inner label 0, so their combined label below is
    # outer_clusters * inner_clusters.
    node_cluster_inner = node_cluster_outer.new_zeros(node_cluster_outer.shape)
    node_cluster_inner[batch_idx, node_idx] = node_cluster_cat_inner[cluster_idx, node_idx_inner]

    node_cluster_cat = node_cluster_outer * inner_clusters + node_cluster_inner
    return landmarks, node_cluster_cat


@torch.no_grad()
def calc_landmarks(method, node_embeddings, nnodes, nlandmarks, reg_scaled, dist,
                   nhashes=1, return_clusters=False, centroids=None, separate=False):
    """Compute landmarks and/or cluster assignments for a pair of node sets.

    Args:
        method: One of 'sampling_uniform', 'sampling_kmeanspp', 'sampling_rls',
            'learned', 'kmeans_hier', 'kmeans' or 'lsh'.
        node_embeddings: Sequence of embedding tensors; the first one must be
            3-D (batch, max_nodes, feature).
        nnodes: Valid-node counts; indexed as `nnodes[i]` per node set
            (presumably shape (2, batch) - confirm against callers).
        nlandmarks: Number of landmarks / clusters / buckets.
        reg_scaled: Only used by 'sampling_rls', as `lmbda = 1 / reg_scaled`
            of the Laplacian kernel.
        dist: Distance object exposing `cdist`.
        nhashes: Number of independent clusterings; must be 1 for the sampling
            and 'learned' methods.
        return_clusters: Also return per-set cluster assignments.
        centroids: Existing centroids for 'learned'; created from a k-means
            run when None.
        separate: Cluster both sets independently (stacked along the batch
            axis) instead of jointly (concatenated along the node axis).
            Only allowed for 'kmeans' and 'kmeans_hier'.

    Returns:
        `(landmarks, node_cluster, centroids)` if `return_clusters` is set,
        otherwise `(landmarks, centroids)`. `landmarks` is None for 'lsh'.

    Raises:
        ValueError: If `method` is not one of the supported names.
    """
    batch_size, max_nodes, emb_size = node_embeddings[0].shape

    if separate:
        assert method in ['kmeans', 'kmeans_hier']
        # Stack the sets along the batch axis so each one is clustered on its own
        node_embeddings_cat = torch.cat(node_embeddings, dim=0)
        nnodes = nnodes.flatten()
    else:
        # Join the sets along the node axis so they share one clustering
        node_embeddings_cat = torch.cat(node_embeddings, dim=1)

    if method == 'sampling_uniform':
        assert nhashes == 1
        landmarks = uniform_sampling_cat(node_embeddings_cat, nnodes, nlandmarks, dist=dist)
        node_cluster_cat = None
    elif method == 'sampling_kmeanspp':
        assert nhashes == 1
        landmarks = kmeanspp_sampling_cat(node_embeddings_cat, nlandmarks, dist=dist, nnodes=nnodes)
        node_cluster_cat = None
    elif method == 'sampling_rls':
        assert nhashes == 1
        # Imported lazily, only when this method is requested
        from utils.recursive_rls import recursive_rls_sampling_cat
        kernel = kernels.Laplacian(dist, lmbda=1 / reg_scaled)
        landmarks = recursive_rls_sampling_cat(node_embeddings_cat, nnodes, nlandmarks, ksamples_inner=5, kernel=kernel)
        node_cluster_cat = None
    elif method == 'learned':
        assert nhashes == 1
        if centroids is None:
            # Initialize the learnable centroids from the k-means centroids of
            # the first batch element only
            centroids = nn.Parameter(torch.empty(nlandmarks, emb_size, device=nnodes.device))
            landmarks_cluster, _ = kmeans_cat_padded(
                node_embeddings_cat, nnodes, kclusters=nlandmarks, niter=10, dist=dist, kmeanspp_init=True)
            with torch.autograd.no_grad():
                centroids.copy_(landmarks_cluster[0])
        # The same centroids are shared by every batch element
        landmarks = centroids[None, :, :].repeat(batch_size, 1, 1)
        node_cluster_cat = None
    elif method == 'kmeans_hier':
        # Branching factor per level: 10 (and 10 again for >= 1000 landmarks),
        # with the last level taking whatever factor remains
        if nlandmarks >= 1000:
            kclusters = [10, 10]
        else:
            kclusters = [10]
        kclusters_above = [int(cs) for cs in np.cumprod(kclusters)]
        assert all([nlandmarks % cs == 0 for cs in kclusters_above])
        kclusters.append(nlandmarks // kclusters_above[-1])

        # Top level: plain k-means over all nodes
        if separate:
            _, node_cluster_cat = kmeans_padded(
                    node_embeddings_cat, nnodes, kclusters=kclusters[0], niter=10, dist=dist, kmeanspp_init=True)
        else:
            _, node_cluster_cat = kmeans_cat_padded(
                    node_embeddings_cat, nnodes, kclusters=kclusters[0], niter=10, dist=dist, kmeanspp_init=True)

        # Lower levels: split every cluster of the level above
        for level in range(1, len(kclusters)):
            landmarks, node_cluster_cat = kmeans_hier_inner(
                    node_embeddings_cat, node_cluster_cat,
                    kclusters[level], kclusters_above[level - 1], niter=10, dist=dist)
        # Add a singleton axis at dim 1, where 'kmeans' stacks its nhashes runs
        node_cluster_cat.unsqueeze_(1)
        if separate:
            landmarks = landmarks.reshape(len(node_embeddings), nhashes, nlandmarks, emb_size)
    elif method == 'kmeans':
        # One independent k-means run per hash, stacked along dim 1
        landmarkss = []
        node_cluster_cats = []
        for _ in range(nhashes):
            if separate:
                landmarks, node_cluster_cat = kmeans_padded(
                        node_embeddings_cat, nnodes, kclusters=nlandmarks, niter=10, dist=dist, kmeanspp_init=True)
            else:
                landmarks, node_cluster_cat = kmeans_cat_padded(
                        node_embeddings_cat, nnodes, kclusters=nlandmarks, niter=10, dist=dist, kmeanspp_init=True)
            landmarkss.append(landmarks)
            node_cluster_cats.append(node_cluster_cat)
        landmarks = torch.stack(landmarkss, dim=1)
        node_cluster_cat = torch.stack(node_cluster_cats, dim=1)
        del landmarkss, node_cluster_cats
    elif method == 'lsh':
        # LSH only produces bucket assignments, no landmarks
        landmarks = None
        node_cluster_cat = get_angular_buckets(node_embeddings_cat, nnodes, nbuckets=nlandmarks, nhashes=nhashes)
    else:
        raise ValueError(f"Unknown landmarks/clustering method: '{method}'")

    if return_clusters:
        if node_cluster_cat is None:
            # Sampling-based methods: derive assignments from the nearest landmark
            assert nhashes == 1
            if separate:
                landmarks = landmarks.reshape(len(node_embeddings), nhashes, nlandmarks, emb_size)
                node_cluster = [get_cluster_assignments(node_embeddings[i], landmarks[i],
                                                        mask_value=nlandmarks, dist=dist, nnodes=nnodes[i])
                                for i in range(2)]
            else:
                node_cluster = [get_cluster_assignments(node_embeddings[i], landmarks,
                                                        mask_value=nlandmarks, dist=dist, nnodes=nnodes[i])
                                for i in range(2)]
        else:
            # Undo the concatenation: split along the axis the sets were joined on
            if separate:
                node_cluster = list(torch.split(node_cluster_cat, batch_size, dim=0))
            else:
                node_cluster = list(torch.split(node_cluster_cat, max_nodes, dim=-1))
        return landmarks, node_cluster, centroids
    else:
        return landmarks, centroids
