# The code is adapted from repo: https://github.com/lucidrains/vector-quantize-pytorch

from functools import partial

import torch
import torch.nn.functional as F
import torch.distributed as distributed
from torch import nn, einsum
from torch.cuda.amp import autocast

from einops import rearrange, repeat, reduce, pack, unpack


def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True


def default(val, d):
    """Return `val`, falling back to `d` only when `val` is None."""
    if exists(val):
        return val
    return d


def noop(*args, **kwargs):
    """Accept any arguments and do nothing (stand-in for an optional callback)."""
    return None


def l2norm(t):
    """Rescale `t` to unit Euclidean (L2) norm along its last dimension."""
    unit = F.normalize(t, dim=-1, p=2)
    return unit


def cdist(x, y):
    """Batched pairwise Euclidean distance.

    `x` is (b, i, d) and `y` is (b, j, d); the result is (b, i, j), computed
    from the expansion |x - y|^2 = |x|^2 + |y|^2 - 2 * <x, y>.
    """
    x_sq = rearrange(reduce(x ** 2, 'b n d -> b n', 'sum'), 'b i -> b i 1')
    y_sq = rearrange(reduce(y ** 2, 'b n d -> b n', 'sum'), 'b j -> b 1 j')
    cross = einsum('b i d, b j d -> b i j', x, y) * -2

    squared = x_sq + y_sq + cross
    # the expansion can come out slightly negative; clamp before the sqrt
    return squared.clamp(min = 0).sqrt()


def log(t, eps = 1e-20):
    """Natural log of `t`, with values below `eps` raised to `eps` first."""
    floored = t.clamp(min = eps)
    return floored.log()


def pack_one(t, pattern):
    """Pack the single tensor `t` with einops `pack`.

    Returns whatever `pack` returns for a one-element list, i.e. the packed
    tensor together with the packed-shape information needed by `unpack_one`.
    """
    packed_and_shapes = pack([t], pattern)
    return packed_and_shapes


def unpack_one(t, ps, pattern):
    """Undo `pack_one`: unpack `t` with einops `unpack` and return the first piece."""
    pieces = unpack(t, ps, pattern)
    return pieces[0]


def uniform_init(*shape):
    """Return a new tensor of the given shape filled by `nn.init.kaiming_uniform_`."""
    # kaiming_uniform_ fills in place and hands the same tensor back
    return nn.init.kaiming_uniform_(torch.empty(shape))


def gumbel_noise(t):
    """Sample Gumbel noise with the shape, dtype and device of `t`.

    Uses the inverse-CDF form -log(-log(u)) with u ~ Uniform(0, 1); the
    clamped `log` helper keeps both logarithms finite.
    """
    uniform = torch.zeros_like(t).uniform_(0, 1)
    inner = -log(uniform)
    return -log(inner)


def gumbel_sample(logits, temperature=1., stochastic=False, straight_through=False, reinmax=False, dim=-1, training=True):
    """Pick one class per slice of `logits`, optionally with a straight-through gradient.

    Args:
        logits: unnormalized scores; classes are laid out along `dim`.
        temperature: softmax / noise temperature. Values <= 0 disable both the
            Gumbel noise and the straight-through estimator.
        stochastic: when True (and `training`), Gumbel noise is added to the
            temperature-scaled logits before the argmax.
        straight_through: when True (and `training`), the returned one-hot
            carries a softmax gradient with respect to `logits`.
        reinmax: use the ReinMax estimator instead of plain straight-through;
            only valid together with `straight_through`.
        dim: axis of `logits` that enumerates the classes.
        training: sampling noise and gradient estimators are only active when True.

    Returns:
        `(ind, one_hot)`: the argmax indices and their one-hot encoding cast to
        the dtype of `logits`. With a straight-through estimator the forward
        value of `one_hot` is still the hard one-hot.

    Raises:
        AssertionError: if `reinmax` is requested without `straight_through`.

    NOTE(review): `F.one_hot` always appends the class axis last, so the hard
    one-hot only lines up with `logits` when `dim` is the last axis.
    """
    dtype, size = logits.dtype, logits.shape[dim]

    if training and stochastic and temperature > 0:
        sampling_logits = (logits / temperature) + gumbel_noise(logits)
    else:
        sampling_logits = logits

    ind = sampling_logits.argmax(dim=dim)
    one_hot = F.one_hot(ind, size).type(dtype)

    assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax'

    if not straight_through or temperature <= 0. or not training:
        return ind, one_hot

    # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612 algorithm 2
    if reinmax:
        pi0 = logits.softmax(dim=dim)
        pi1 = (one_hot + (logits / temperature).softmax(dim=dim)) / 2
        # Re-express pi1 as a softmax over the class axis so its gradient flows
        # through `logits`. This must normalize over `dim`: a hard-coded axis 1
        # mixes unrelated rows whenever `logits` has more than two dimensions.
        pi1 = ((log(pi1) - logits).detach() + logits).softmax(dim=dim)
        pi2 = 2 * pi1 - 0.5 * pi0
        # forward value stays the hard one-hot; only the gradient comes from pi2
        one_hot = pi2 - pi2.detach() + one_hot
    else:
        pi1 = (logits / temperature).softmax(dim=dim)
        one_hot = one_hot + pi1 - pi1.detach()

    return ind, one_hot


def sample_vectors(samples, num):
    """Pick `num` rows of `samples` at random.

    Rows are drawn without replacement when enough are available; otherwise
    they are drawn with replacement so that `num` rows are always returned.
    """
    available = samples.shape[0]

    if available < num:
        picks = torch.randint(0, available, (num,), device = samples.device)
        return samples[picks]

    picks = torch.randperm(available, device = samples.device)[:num]
    return samples[picks]


def batched_sample_vectors(samples, num):
    """Apply `sample_vectors` independently to every slice along dim 0 and restack."""
    picked = []
    for group in samples.unbind(dim = 0):
        picked.append(sample_vectors(group, num))
    return torch.stack(picked, dim = 0)


def pad_shape(shape, size, dim = 0):
    """Return `shape` as a list with the extent at axis `dim` replaced by `size`.

    Negative `dim` counts from the end, as in Python indexing. An axis outside
    the shape leaves every extent unchanged.
    """
    # normalize negative axes so they are no longer silently ignored
    target = dim + len(shape) if dim < 0 else dim
    return [size if i == target else s for i, s in enumerate(shape)]


def sample_multinomial(total_count, probs):
    """Draw multinomial counts for `total_count` trials over categories `probs`.

    The draw is done on CPU as a chain of conditional binomials: each category
    takes a binomial share of the trials still unassigned, with its probability
    renormalized by the probability mass not yet consumed. The long-typed
    counts are returned on the original device of `probs`.
    """
    origin_device = probs.device
    cpu_probs = probs.cpu()

    trials_left = cpu_probs.new_full((), total_count)
    mass_left = cpu_probs.new_ones(())
    counts = torch.empty_like(cpu_probs, dtype = torch.long)

    for idx in range(len(cpu_probs)):
        p = cpu_probs[idx]
        drawn = torch.binomial(trials_left, p / mass_left)
        counts[idx] = drawn
        trials_left = trials_left - drawn
        mass_left = mass_left - p

    return counts.to(origin_device)


def all_gather_sizes(x, dim):
    """Gather the extent of `x` along `dim` from every rank.

    Returns a long tensor with one entry per rank, in rank order. Requires an
    initialized `torch.distributed` process group.
    """
    local_size = torch.tensor(x.shape[dim], dtype = torch.long, device = x.device)
    world_size = distributed.get_world_size()
    gathered = [torch.empty_like(local_size) for _ in range(world_size)]
    distributed.all_gather(gathered, local_size)
    return torch.stack(gathered)


def all_gather_variably_sized(x, sizes, dim = 0):
    """Collect every rank's `x`, where ranks may differ in extent along `dim`.

    `sizes[i]` is the extent of rank i's tensor along `dim`. Each rank
    broadcasts its own tensor in turn; the other ranks receive into a freshly
    allocated buffer of the matching shape. Returns the list of tensors in
    rank order.
    """
    my_rank = distributed.get_rank()
    gathered = []

    for src_rank, src_size in enumerate(sizes):
        if src_rank == my_rank:
            buf = x
        else:
            buf = x.new_empty(pad_shape(x.shape, src_size, dim))
        distributed.broadcast(buf, src = src_rank, async_op = True)
        gathered.append(buf)

    # the broadcasts were issued asynchronously; synchronize before returning
    distributed.barrier()
    return gathered


def sample_vectors_distributed(local_samples, num):
    """Sample `num` vectors in total across all ranks.

    `local_samples` must have a leading axis of size 1, which is stripped here
    and restored on the result. Rank 0 decides how many vectors each rank
    contributes (multinomial over the per-rank sample counts) and broadcasts
    that split; every rank then samples its share locally and the shares are
    exchanged so all ranks end up with the same concatenated result.
    """
    local_samples = rearrange(local_samples, '1 ... -> ...')

    rank = distributed.get_rank()
    all_num_samples = all_gather_sizes(local_samples, dim = 0)

    if rank != 0:
        # placeholder that receives rank 0's decision in the broadcast below
        samples_per_rank = torch.empty_like(all_num_samples)
    else:
        rank_probs = all_num_samples / all_num_samples.sum()
        samples_per_rank = sample_multinomial(num, rank_probs)

    distributed.broadcast(samples_per_rank, src = 0)
    samples_per_rank = samples_per_rank.tolist()

    own_share = sample_vectors(local_samples, samples_per_rank[rank])
    shares = all_gather_variably_sized(own_share, samples_per_rank, dim = 0)
    combined = torch.cat(shares, dim = 0)

    return rearrange(combined, '... -> 1 ...')


def batched_bincount(x, *, minlength):
    """Row-wise bincount: count occurrences of each index value in every row of `x`.

    Returns a (batch, minlength) tensor with the dtype and device of `x`.
    """
    counts = x.new_zeros(x.shape[0], minlength)
    # add 1 at every position named by an index in `x`
    return counts.scatter_add_(-1, x, torch.ones_like(x))


def kmeans(
    samples,
    num_clusters,
    num_iters = 10,
    use_cosine_sim = False,
    sample_fn = batched_sample_vectors,
    all_reduce_fn = noop
):
    """Run batched k-means, one independent clustering per leading slice of `samples`.

    Args:
        samples: points of shape (h, n, d).
        num_clusters: number of centroids per clustering.
        num_iters: number of assignment / update rounds.
        use_cosine_sim: score by dot product and l2-normalize the updated
            centroids, instead of using Euclidean distance.
        sample_fn: callable `(samples, num_clusters)` giving the initial centroids.
        all_reduce_fn: called in place on the counts and on the updated
            centroids, so they can be combined across processes.

    Returns:
        `(means, bins)`: centroids of shape (h, num_clusters, d) and the
        per-cluster assignment counts from the last round.
    """
    num_codebooks, dim = samples.shape[0], samples.shape[-1]
    dtype = samples.dtype

    means = sample_fn(samples, num_clusters)

    for _ in range(num_iters):
        # scores are "larger is closer" in both modes
        if use_cosine_sim:
            scores = samples @ rearrange(means, 'h n d -> h d n')
        else:
            scores = -cdist(samples, means)

        assignments = torch.argmax(scores, dim = -1)
        bins = batched_bincount(assignments, minlength = num_clusters)
        all_reduce_fn(bins)

        empty = bins == 0
        # avoid dividing by zero for clusters that received no points
        safe_counts = bins.masked_fill(empty, 1)

        sums = assignments.new_zeros(num_codebooks, num_clusters, dim, dtype = dtype)
        sums.scatter_add_(1, repeat(assignments, 'h n -> h n d', d = dim), samples)

        centroids = sums / rearrange(safe_counts, '... -> ... 1')
        all_reduce_fn(centroids)

        if use_cosine_sim:
            centroids = l2norm(centroids)

        # empty clusters keep their previous centroid
        means = torch.where(rearrange(empty, '... -> ... 1'), means, centroids)

    return means, bins


def batched_embedding(indices, embeds):
    """Look up code vectors by index, separately for every (expert, head).

    indices -> (expert_num, heads_num, batch_size)
    embeds -> (expert_num, heads_num, codebook_size, dim)
    returns -> (expert_num, heads_num, batch_size, dim)
    """
    feature_dim = embeds.shape[-1]
    # gather needs one index per output element, so repeat each index across the feature axis
    gather_index = repeat(indices, 'e h b -> e h b d', d=feature_dim)
    return embeds.gather(2, gather_index)


def orthogonal_loss_fn(t):
    """Orthogonality penalty on codes `t` of shape (h, n, d).

    eq (2) from https://arxiv.org/abs/2112.00384: the mean squared cosine
    similarity over all code pairs within each of the `h` codebooks, minus 1/n.
    """
    num_codebooks, num_codes = t.shape[:2]
    unit_codes = l2norm(t)
    gram = einsum('h i d, h j d -> h i j', unit_codes, unit_codes)
    mean_sq_sim = (gram ** 2).sum() / (num_codebooks * num_codes ** 2)
    return mean_sq_sim - (1 / num_codes)


# def orthogonal_loss_fn(t):  # t -> (expert_num, heads_num, codebook_size, dim)
#     codebook_num = t.shape[0]
#     if codebook_num == 1:
#         return 0.
#     mean_codes = t.mean(dim=1).mean(dim=1)
#     normed_codes = l2norm(mean_codes)
#     cosine_sim = einsum('i d, j d -> i j', normed_codes, normed_codes)
#     triu_i, triu_j = torch.triu_indices(codebook_num, codebook_num, offset=1)
#     off_diag_sim = cosine_sim[triu_i, triu_j]
#     return (off_diag_sim ** 2).mean()


class CosineSimCodebook(nn.Module):
    """Codebook that matches inputs to codes by dot-product (cosine) similarity.

    One independent set of `codebook_size` codes is kept per (expert, head);
    `embed` has shape (num_codebooks, codebooks_heads, codebook_size, dim).
    The forward pass does not normalize anything itself, so the scores are
    cosine similarities only when the caller passes l2-normalized inputs.
    """

    def __init__(
        self,
        dim,
        codebook_size,
        num_codebooks,
        codebooks_heads,
        kmeans_init,
        kmeans_iters,
        sync_kmeans,
        use_ddp,
        sample_codebook_temp,
        gumbel_sample,
    ):
        """
        Args:
            dim: size of each code vector.
            codebook_size: number of codes per (expert, head).
            num_codebooks: number of experts (first axis of `embed`).
            codebooks_heads: number of heads (second axis of `embed`).
            kmeans_init: when True the codes start as zeros and are replaced by
                k-means centroids of the first batch passed to `forward`.
            kmeans_iters: number of k-means rounds used for that initialization.
            sync_kmeans, use_ddp: when both are truthy, k-means uses the
                distributed sampler and `distributed.all_reduce`.
            sample_codebook_temp: temperature handed to `gumbel_sample`.
            gumbel_sample: callable `(logits, dim=, temperature=, training=)`
                returning `(indices, one_hot)`.
        """
        super().__init__()
        if not kmeans_init:
            embed = l2norm(uniform_init(num_codebooks, codebooks_heads, codebook_size, dim))
        else:
            # placeholder values; overwritten by `init_embed_` on the first forward pass
            embed = torch.zeros(num_codebooks, codebooks_heads, codebook_size, dim)

        assert callable(gumbel_sample)
        self.gumbel_sample = gumbel_sample
        self.sample_codebook_temp = sample_codebook_temp

        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
        self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop

        # 1. once the codes hold real values, 0. while k-means init is still pending
        self.register_buffer('initted', torch.Tensor([not kmeans_init]))
        self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
        self.register_buffer('embed_avg', embed.clone())

        self.embed = nn.Parameter(embed)  # embed -> (expert_num, heads_num, codebook_size, dim)

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialize the codes with k-means centroids of `data`.

        data -> (expert_num, heads_num, batch_size, dim)

        `kmeans` clusters along a single leading axis, so the (expert, head)
        axes are flattened for the call and restored afterwards.

        NOTE(review): the distributed sampler strips a leading axis of size 1,
        so synchronized k-means presumably only works when
        expert_num * heads_num == 1 - confirm before relying on it.
        """
        num_experts = data.shape[0]
        flat = rearrange(data, 'e h b d -> (e h) b d')

        embed, cluster_size = kmeans(
            flat,
            self.codebook_size,
            self.kmeans_iters,
            use_cosine_sim = True,
            sample_fn = self.sample_fn,
            all_reduce_fn = self.kmeans_all_reduce_fn
        )

        embed = rearrange(embed, '(e h) c d -> e h c d', e = num_experts)
        cluster_size = rearrange(cluster_size, '(e h) c -> e h c', e = num_experts)

        embed_sum = embed * rearrange(cluster_size, '... -> ... 1')

        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed_sum)
        # the `cluster_size` buffer has no head axis, so store the counts summed over heads
        self.cluster_size.data.copy_(cluster_size.sum(dim = 1))
        self.initted.data.copy_(torch.Tensor([True]))

    @autocast(enabled=False)
    def forward(self, x):
        """Quantize `x` to the best-scoring code of every (expert, head).

        x -> (expert_num, heads_num, batch_size, dim)

        Returns:
            `(quantize, embed_ind)`: the selected code vectors, shaped like
            `x`, and their indices of shape (expert_num, heads_num, batch_size).
        """
        # `initted` is a registered buffer, i.e. always a tensor and never None,
        # so the pending-initialization check has to look at its value
        if not self.initted:
            self.init_embed_(x)

        dist = einsum('e h b d, e h c d -> e h b c', x, self.embed)

        embed_ind, embed_onehot = self.gumbel_sample(dist, dim=-1, temperature=self.sample_codebook_temp, training=self.training)

        if self.training:
            # the matmul keeps any gradient carried by the (straight-through) one-hot
            quantize = einsum('e h b c, e h c d -> e h b d', embed_onehot, self.embed)
        else:
            quantize = batched_embedding(embed_ind, self.embed)

        return quantize, embed_ind


class VectorQuantize(nn.Module):
    """Mixture-of-experts vector quantizer.

    The input is projected to `heads` vectors, each l2-normalized and quantized
    by every expert's cosine-similarity codebook. A linear gate then selects
    the `topk` experts per sample and mixes their quantized outputs with
    softmax weights.
    """

    def __init__(
        self,
        dim,
        codebook_size,
        expert_num,
        heads,
        topk,
        kmeans_init = False,
        kmeans_iters = 10,
        sync_kmeans = True,
        stochastic_sample_codes = False,
        sample_codebook_temp = 1.,
        straight_through = False,
        reinmax = False,  # using reinmax for improved straight-through, assuming straight through helps at all
        sync_codebook = None,
        sync_update_v = 0.  # the v that controls optimistic vs pessimistic update for synchronous update rule
    ):
        """
        Args:
            dim: feature size of the input and of each code vector.
            codebook_size: number of codes per (expert, head).
            expert_num: number of expert codebooks; also the gate's output size.
            heads: number of heads the input is projected and split into.
            topk: number of experts mixed per sample.
            kmeans_init, kmeans_iters, sync_kmeans: forwarded to the codebook.
            stochastic_sample_codes, straight_through, reinmax: forwarded to
                `gumbel_sample`.
            sample_codebook_temp: sampling temperature forwarded to the codebook.
            sync_codebook: whether the codebook uses distributed helpers;
                defaults to True only when a process group with more than one
                rank is initialized.
            sync_update_v: when > 0 (and `field` is given to `forward`), scales
                an extra gradient-only term added to the quantized output.
        """
        super().__init__()
        self.dim = dim
        self.expert_num = expert_num
        self.heads = heads
        self.topk = topk

        self.gate = nn.Linear(dim, expert_num)

        # each head gets its own `dim`-sized slice of the projected input
        codebook_input_dim = dim * heads
        requires_projection = codebook_input_dim != dim
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

        self.sync_update_v = sync_update_v

        gumbel_sample_fn = partial(
            gumbel_sample,
            stochastic = stochastic_sample_codes,
            reinmax = reinmax,
            straight_through = straight_through
        )

        if not exists(sync_codebook):
            sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1

        self._codebook = CosineSimCodebook(
            dim = dim,
            num_codebooks = expert_num,
            codebook_size = codebook_size,
            codebooks_heads = heads,
            kmeans_init = kmeans_init,
            kmeans_iters = kmeans_iters,
            sync_kmeans = sync_kmeans,
            use_ddp = sync_codebook,
            sample_codebook_temp = sample_codebook_temp,
            gumbel_sample = gumbel_sample_fn
        )

        self.softmax = nn.Softmax(dim=1)

    @property
    def codebook(self):
        """The codebook parameter, shape (expert_num, heads_num, codebook_size, dim)."""
        codebook = self._codebook.embed
        return codebook

    @codebook.setter
    def codebook(self, codes):
        # `embed` is a leaf nn.Parameter that requires grad; autograd rejects an
        # in-place write to it unless gradient tracking is switched off.
        with torch.no_grad():
            self._codebook.embed.copy_(codes)

    def forward(self, x, field=None):
        """Quantize `x` and optionally compute the auxiliary losses.

        Args:
            x: input of shape (batch_size, dim).
            field: optional target passed to `F.cross_entropy` against the gate
                logits - presumably one expert index per sample; confirm
                against the caller.

        Returns:
            `(quantize, loss)`. `quantize` has shape (batch_size, dim). `loss`
            is a one-element zero tensor when `field` is None, otherwise the
            tuple `(contrastive_loss, field_loss)`.
        """
        origin_x = x  # origin_x -> (batch_size, dim)

        x = self.project_in(x)
        x = rearrange(x, 'b (h d) -> h b d', h=self.heads)
        x = repeat(x, 'h b d -> e h b d', e=self.expert_num)  # x -> (expert_num, heads_num, batch_size, dim)
        x = l2norm(x)

        quantize, embed_ind = self._codebook(x)  # quantize -> (expert_num, heads_num, batch_size, dim)

        # straight through: forward uses the codes, backward treats quantization as identity
        quantize = x + (quantize - x).detach()

        # moe: softmax over the top-k gate logits, zero weight for every other expert
        choose_expert = self.gate(origin_x)
        top_logits, top_indices = choose_expert.topk(self.topk, dim=1)
        probs = F.softmax(top_logits, dim=1)
        zeros = torch.zeros_like(choose_expert, requires_grad=True)
        gates = zeros.scatter(1, top_indices, probs)
        expert_weights = rearrange(gates, 'b e -> e b 1 1')
        quantize = einsum('e b m n, e h b d -> b h d', expert_weights, quantize)  # quantize -> (batch_size, heads_num, dim)

        quantize = rearrange(quantize, 'b h d -> b (h d)')
        quantize = self.project_out(quantize)

        # aggregate loss
        loss = torch.tensor([0.], device=x.device, requires_grad=self.training)

        if field is not None:
            if self.sync_update_v > 0.:
                # numerically a no-op; only scales the gradient through `quantize`
                quantize = quantize + self.sync_update_v * (quantize - quantize.detach())

            contrastive_loss = loss + self.contrastive_loss(quantize, origin_x)

            field_loss = loss + F.cross_entropy(choose_expert, field)

            loss = (contrastive_loss, field_loss)

        return quantize, loss

    def sim(self, z1, z2):
        """Pairwise cosine similarity between the rows of `z1` and the rows of `z2`."""
        z1 = nn.functional.normalize(z1)
        z2 = nn.functional.normalize(z2)

        return torch.mm(z1, z2.t())

    def contrastive_loss(self, z1, z2, tau=0.5):
        """Contrastive loss treating row i of `z1` and row i of `z2` as the positive pair.

        Similarities are exponentiated with temperature `tau`. The denominator
        sums, for each row, the z1-z1, z1-z2 and z2-z2 similarities. Returns
        the mean over rows.
        """
        refl_sim = torch.exp(self.sim(z1, z1) / tau)
        between_sim = torch.exp(self.sim(z1, z2) / tau)
        z2_sim = torch.exp(self.sim(z2, z2) / tau)
        loss = -torch.log(between_sim.diag() / (refl_sim.sum(1) + between_sim.sum(1) + z2_sim.sum(1)))

        return loss.mean()
