import numpy as np
import torch

from utils.landmarks import get_matching_embeddings, get_kmeanspp_sample


def _eigvalsh_ascending(mat):
    '''Return the eigenvalues (ascending) of a batch of symmetric matrices.'''
    # torch.symeig was deprecated and later removed; torch.linalg.eigvalsh is its replacement.
    # UPLO='U' reads the upper triangle, which is what torch.symeig did by default.
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigvalsh'):
        return torch.linalg.eigvalsh(mat, UPLO='U')
    return torch.symeig(mat).eigenvalues  # fallback for old PyTorch builds without torch.linalg


@torch.no_grad()
def recursive_rls_sampling_cat(X, nnodes, ksamples: int, kernel, lmbda_0=0, ksamples_inner: int = None):
    '''
    Batched recursive ridge leverage score (RLS) sampling of landmark embeddings.

    Paper: https://arxiv.org/pdf/1605.07583.pdf
    Adapted from https://github.com/axelv/recursive-nystrom

    The node axis of X holds two padded halves of equal length (max_nodes each)
    concatenated one after the other. Padding entries are excluded from sampling.

    Args:
        X: Tensor of shape (batch_size, max_nodes2, emb_size). max_nodes2 is assumed to be
            even and larger than ksamples (otherwise there is no recursion level to run).
        nnodes: Integer tensor of shape (2, batch_size); nnodes[0] / nnodes[1] give the number
            of real (non-padding) entries in the first / second half of every batch item.
        ksamples: Number of landmarks returned per batch item.
        kernel: Object providing `pairwise_similarity(X, X)` (indexed here as a
            (batch_size, max_nodes2) tensor, i.e. presumably the kernel diagonal - confirm
            against the kernel implementation), `csim(A, B)` (indexed here as a
            (batch, rows of A, rows of B) tensor) and a `dist` attribute.
        lmbda_0: Lower bound for the ridge parameter, scaled by the current sample size.
        ksamples_inner: Number of samples drawn at the intermediate recursion levels.
            Defaults to ksamples.

    Returns:
        Tensor of shape (batch_size, ksamples, emb_size) with the sampled embeddings.

    Raises:
        AssertionError: If any batch item has fewer than ksamples real nodes in total.
    '''
    batch_size, max_nodes2, _ = X.shape
    max_nodes = max_nodes2 // 2
    nnodes2 = nnodes.sum(0)
    assert torch.all(nnodes2 >= ksamples)

    # We can reduce ksamples_inner for faster runtimes, as suggested in the paper (Sec. 5.2.1)
    if ksamples_inner is None:
        ksamples_inner = ksamples

    device = X.device
    # Column of batch ids; expanded to whatever index shape is needed for advanced indexing.
    batch_col = torch.arange(batch_size, dtype=torch.long, device=device)[:, None]

    # True wherever an entry is padding, separately for both halves of the node axis.
    node_pos = torch.arange(max_nodes, dtype=torch.long, device=device).expand(batch_size, -1)
    mask_padding = torch.zeros((batch_size, 2 * max_nodes), dtype=torch.bool, device=device)
    mask_padding[:, :max_nodes] = node_pos >= nnodes[0, :, None]
    mask_padding[:, max_nodes:] = node_pos >= nnodes[1, :, None]

    n_oversample = np.log(ksamples)
    # np.int was removed from NumPy; a plain Python int is all that is needed here.
    k = int(np.ceil(ksamples / (4 * n_oversample)))
    n_levels = int(np.ceil(np.log(max_nodes2 / ksamples) / np.log(2)))

    # Random permutation per batch item that (almost surely) lists the real nodes first:
    # padding gets a tiny but non-zero probability because multinomial without replacement
    # cannot draw all max_nodes2 entries if some of them have probability exactly 0.
    sampling_prob = ~mask_padding + torch.finfo(torch.float).tiny
    perm = torch.multinomial(sampling_prob, max_nodes2, replacement=False)
    mask_padding_perm = mask_padding[batch_col.expand_as(perm), perm]

    # set up sizes for recursive levels (halved at every level)
    size_list = [max_nodes2]
    for lvl in range(1, n_levels + 1):
        size_list += [int(np.ceil(size_list[lvl - 1] / 2))]

    # indices of points selected at previous level of recursion
    # at the base level it's just a uniform sample of ~ ksamples points

    # Make sure we only get real nodes (no padding) in first sample
    first_sample_mask = ~(mask_padding_perm[:, :size_list[-2]])
    sample = torch.multinomial(first_sample_mask.float(), size_list[-1], replacement=False)
    batch_indices = batch_col.expand_as(sample)
    indices = perm[batch_indices, sample]

    # we need the diagonal of the whole kernel matrix, so compute upfront
    k_diag = kernel.pairwise_similarity(X, X)

    # Unit weights for the uniform base sample. Created in the kernel's floating dtype
    # (ones_like(sample) would inherit the integer dtype of the index tensor).
    weights = torch.ones(sample.shape, dtype=k_diag.dtype, device=device)

    # Main recursion, unrolled for efficiency
    for lvl in reversed(range(n_levels)):
        sample_size = sample.shape[1]
        # indices of current uniform sample
        current_indices = perm[:, :size_list[lvl]]
        batch_current_indices = batch_col.expand_as(current_indices)

        # build sampled kernel: all rows and sampled columns
        KS = kernel.csim(X[batch_current_indices, current_indices], X[batch_indices, indices])
        SKS = KS[batch_indices, sample, :]  # sampled rows and sampled columns

        # optimal lambda for taking O(k log(k)) samples
        if k >= sample_size:
            # for the rare chance we take less than k samples in a round
            # don't set to exactly 0 to avoid stability issues.
            # One value per batch item so that the lmbda[:, None] indexing below works.
            lmbda = SKS.new_full((batch_size,), 10e-6)
        else:
            # eigenvalues equal roughly the number of points per cluster, maybe this should scale with n?
            # can be interpreted as the zoom level
            # TODO: Starting with PyTorch 1.5 we can use torch.lobpcg for this.
            # lambda = (trace - sum of the k largest eigenvalues) / k of the weighted SKS
            weighted_trace = torch.sum(torch.diagonal(SKS, dim1=1, dim2=2) * (weights ** 2), dim=1)
            eigenvalues = _eigvalsh_ascending(SKS * weights[:, :, None] * weights[:, None, :])
            top_k_sum = torch.sum(eigenvalues[:, sample_size - k:sample_size], dim=1)
            lmbda = (weighted_trace - top_k_sum) / k
        lmbda = torch.clamp(lmbda, min=lmbda_0 * sample_size)
        if torch.any(lmbda == lmbda_0 * sample_size):
            print(f"Set lambda to {lmbda}.")

        # compute and sample by lambda ridge leverage scores
        R = torch.inverse(SKS + torch.diag_embed(lmbda[:, None] / weights**2))
        R = KS @ R
        # R = torch.lstsq(KS, SKS + torch.diag_embed(lmbda[:, None] / weights**2))  # Not batched

        # Clamping helps avoid numerical issues, unnecessary in theory
        leverage_score_inner = torch.clamp(
            k_diag[batch_current_indices, current_indices] - torch.sum(R * KS, dim=2), min=0)
        is_last_level = lvl == 0
        if is_last_level:
            leverage_score = torch.clamp(leverage_score_inner / lmbda[:, None], max=1)
        else:
            # We probably don't need n_oversample or clamp, since we sample a fixed number
            leverage_score = torch.clamp(n_oversample * leverage_score_inner / lmbda[:, None], max=1)

        # To enable batching we sample exactly ksamples instead of independently sampling each column
        # in the intermediate levels as well and mask out the padding for sampling
        leverage_score.masked_fill_(mask_padding_perm[:, :size_list[lvl]], 0)
        # sample = np.random.choice(max_nodes2, size=ksamples, replace=False, p=p.numpy())
        n_draw = ksamples if is_last_level else ksamples_inner
        sample = torch.multinomial(leverage_score, n_draw, replacement=False)

        # The sample width may have changed, so rebuild the matching batch index view.
        batch_indices = batch_col.expand_as(sample)

        if not is_last_level:
            weights = torch.sqrt(1. / leverage_score[batch_indices, sample])

        indices = perm[batch_indices, sample]

    samples = X[batch_indices, indices]

    # Resample to prevent matching samples
    batch_matching, idx_matching1, _ = get_matching_embeddings(samples, dist=kernel.dist)
    while len(batch_matching) > 0:
        samples[batch_matching, idx_matching1] = get_kmeanspp_sample(X[batch_matching], samples[batch_matching],
                                                                     mask_padding[batch_matching], dist=kernel.dist)
        batch_matching, idx_matching1, _ = get_matching_embeddings(samples, dist=kernel.dist)

    return samples


# Small check to test if the algorithm's output makes sense:
# sample landmarks from three Gaussian clusters of very different sizes and compare the
# cluster balance among the landmarks with the balance among the real (non-padding) points.
if __name__ == "__main__":
    from utils import distances, kernels

    # Fall back to the CPU so the check also runs on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    n1 = 100
    n2 = 5000
    n3 = 4900
    n = torch.tensor([n1, n2, n3])
    torch.manual_seed(10)
    # Seed NumPy as well: the shuffle below uses NumPy's global RNG.
    np.random.seed(10)
    X = torch.cat(
            [torch.distributions.MultivariateNormal(torch.tensor([50, 10], dtype=torch.float),
                                                    torch.eye(2, dtype=torch.float)).sample((n1,)),
             torch.distributions.MultivariateNormal(torch.tensor([-70, -70], dtype=torch.float),
                                                    torch.eye(2, dtype=torch.float)).sample((n2,)),
             torch.distributions.MultivariateNormal(torch.tensor([90, -40], dtype=torch.float),
                                                    torch.eye(2, dtype=torch.float)).sample((n3,))],
            dim=0).to(device)
    # Cluster label (1, 2 or 3) of every point
    y = torch.cat([torch.ones((n1,)) * 1,
                   torch.ones((n2,)) * 2,
                   torch.ones((n3,)) * 3]).to(device)
    idx = np.arange(X.shape[0])
    np.random.shuffle(idx)
    X = X[idx]
    y = y[idx]

    batch_size = 200
    num_nodes = 2000
    # The 10000 points form two halves of 5000; only the first num_nodes of each half are
    # declared real via nnodes below. Overwrite the padding of the second half with a
    # far-away value so that accidentally sampled padding would be easy to spot.
    X[5000 + num_nodes:, :] = 2000
    X = X[None, :, :].expand([batch_size, -1, -1])
    landmarks = recursive_rls_sampling_cat(
            X,
            nnodes=torch.full([2, batch_size], num_nodes, dtype=torch.long, device=X.device),
            ksamples=10, kernel=kernels.Laplacian(distances.PNorm(p=2), lmbda=0.1))

    # Node positions of all points that coincide exactly with a landmark
    indices = torch.where(torch.cdist(X.double(), landmarks.double()) == 0)[1]
    y_total = y[indices]

    u, c = torch.unique(y_total, return_counts=True)
    _, n_chosen = torch.unique(torch.cat([y[:num_nodes], y[5000:5000 + num_nodes]]), return_counts=True)
    print("Real balance:", n_chosen * 1.0 / n_chosen.sum())
    print("RLS balance:", c * 1.0 / c.sum())