from jax import numpy as jnp, nn as jax_nn
import jax
import distrax


def sample_tok(logits, rng, temperature=0):
    """
    Sample one token per row, given its logits
    Args:
        logits: jax.Array
            Array(batch_size, vocab_size) of unnormalised scores
        rng: jax.Array
            Jax PRNG key
        temperature: float
            constant which flattens or peaks the distribution;
            0 selects the most likely token (greedy decoding)
    Returns:
        Array(batch_size,) with the sampled token ids
    """
    # no need to normalize if we need only the peak of the dist
    if temperature == 0:
        return jnp.argmax(logits, axis=-1)
    # temperature: low -> sample the original (i.e elem with high prob)
    #              high -> make a wider (closer to uniform) dist as temp increases -> explore more
    # jax.random.categorical takes unnormalised log-probabilities, so the
    # scaled logits can be passed directly without an explicit softmax.
    return jax.random.categorical(rng, logits / temperature, axis=-1)


def nucleus_sampling(logits, rng, top_p, temperature=1):
    """
    Implement nucleus (top-p) sampling

    Tokens are ranked by probability; the ranked list is cut right after the
    first token whose cumulative probability exceeds `top_p` (the most likely
    token is therefore always kept) and one token is drawn from what is left,
    in proportion to its probability.

    Args:
        logits: jax.Array
            Array(batch_size, vocab_size) of unnormalised scores
        rng: jax.Array
            Jax PRNG key
        top_p: float
            threshold for the cumulative distribution
        temperature: float
            constant which flattens or peaks the distribution; a value of 0
            is replaced by 1e-6 to avoid dividing by zero
    Returns:
        Array(batch_size,) with the sampled token ids
    """
    if temperature == 0:
        temperature += 1e-6
    samp_probs = jax_nn.softmax(logits / temperature, axis=-1)

    # rank the vocabulary from most to least likely
    sorted_indexes = jnp.argsort(samp_probs, axis=-1)[:, ::-1]
    sorted_probs = jnp.take_along_axis(samp_probs, sorted_indexes, axis=-1)
    cumulative_probs = jnp.cumsum(sorted_probs, axis=-1)

    # Tokens whose cumulative probability is above the threshold are the tail.
    above_threshold = cumulative_probs > top_p
    # Shift the mask one step to the right so that the first token crossing
    # the threshold is kept as well; rank 0 is never removed.
    to_remove = jnp.concatenate(
        [jnp.zeros_like(above_threshold[:, :1]), above_threshold[:, :-1]],
        axis=-1,
    )

    # Removed tokens get log-probability -inf (probability 0). jnp.where is
    # used rather than boolean-mask indexing because the mask depends on the
    # data and such an index cannot be used when `logits` is traced by jit.
    kept_log_probs = jnp.where(to_remove, -jnp.inf, jnp.log(sorted_probs))

    # sample a rank from the truncated distribution (categorical accepts
    # unnormalised log-probabilities, so no renormalisation is required)
    sampled_rank = jax.random.categorical(rng, kept_log_probs, axis=-1)
    # reindex from sorted to real index
    return jnp.take_along_axis(sorted_indexes, sampled_rank[:, None], axis=-1)[:, 0]


def topK_sampling(logits, rng, top_k, temperature):
    """
    Select top K logits, renormalise and sample
    Args:
        logits: jax.Array
            Array(batch_size, vocab_size) of unnormalised scores
        rng: jax.Array
            Jax PRNG key
        top_k: int
            number of highest-scoring tokens to sample from
        temperature: float
            constant which flattens or peaks the distribution;
            0 selects the most likely token (greedy decoding)
    Returns:
        Array(batch_size,) with the sampled token ids
    """
    # dividing by a zero temperature is undefined; fall back to the peak
    if temperature == 0:
        return jnp.argmax(logits, axis=-1)
    # Softmax keeps the ordering of its inputs, so the top-k scaled logits
    # are the top-k probabilities, and sampling from them as unnormalised
    # log-probabilities equals sampling from the renormalised top-k probs.
    top_logits, topk_idxs = jax.lax.top_k(logits / temperature, k=top_k)
    # one independent draw per batch row, shape (batch_size,)
    sampled_rank = jax.random.categorical(rng, top_logits, axis=-1)
    # reindex from sorted to real index
    return jnp.take_along_axis(topk_idxs, sampled_rank[..., None], axis=-1)[..., 0]


#### Testing ##########
def test_nucleus():
    """
    Check nucleus_sampling on a two-row batch with top_p=0.5

    Row 0 has one token holding more than half of the probability mass
    (index 4); in row 1 the mass is shared by indices 1 and 2.
    """
    rng = jax.random.PRNGKey(seed=0)
    logits = jnp.array([[1, 2, 3, 4, 5], [10, 100, 100, 2, 30]])
    # the threshold keyword of nucleus_sampling is `top_p`
    res = nucleus_sampling(logits=logits, rng=rng, top_p=0.5, temperature=1)
    samples = res.tolist()
    jax.debug.print("Samples: {x}, Expected  [[4], [1 or 2]]", x=samples)
    assert samples[0] == 4
    assert samples[1] in (1, 2)


def test_topK_sampling():
    """
    Check topK_sampling on a two-row batch with top_k=2

    The two highest logits are at indices 3 and 4 in row 0 and at
    indices 1 and 2 in row 1, so each sample must come from those pairs.
    """
    rng = jax.random.PRNGKey(seed=1)
    logits = jnp.array([[1, 2, 3, 4, 5], [10, 100, 100, 2, 30]])
    res = topK_sampling(logits, rng, top_k=2, temperature=1)
    samples = res.tolist()
    jax.debug.print("Samples: {x}, Expected  [[3 or 4], [1 or 2]]", x=samples)
    assert samples[0] in (3, 4)
    assert samples[1] in (1, 2)


if __name__ == "__main__":
    # Manual smoke run when the file is executed as a script: only the
    # top-k check is enabled, the nucleus check is left commented out.
    # test_nucleus()
    test_topK_sampling()
