import torch
from flashinfer.sampling import (
    top_k_renorm_probs,
    top_p_renorm_probs,
    top_k_sampling_from_probs,
    top_p_sampling_from_probs,
    top_k_top_p_sampling_from_logits,
    sampling_from_logits,
)

def get_sampling_probs(logits, top_p, top_k, temperature):
    """Return the per-position token distribution implied by the sampling settings.

    The settings are interpreted the same way as in `sample`: a
    non-positive `temperature` means greedy decoding, and a non-positive
    `top_k` / `top_p` disables that filter.

    Args:
        logits: Tensor of shape (bsz, voc_size) or (bsz, seqlen, voc_size).
        top_p: Threshold passed to `top_p_renorm_probs`; skipped when <= 0.
        top_k: Count passed to `top_k_renorm_probs`; skipped when <= 0.
        temperature: Divisor applied to the logits before the softmax.
            When <= 0 the result is a one-hot distribution on the argmax.

    Returns:
        Tensor of shape (bsz, seqlen, voc_size); seqlen is 1 for 2-D input.

    Raises:
        ValueError: If `logits` is neither 2-D nor 3-D.
    """
    if len(logits.shape) == 2:
        bsz, voc_size = logits.shape
        seqlen = 1
    elif len(logits.shape) == 3:
        bsz, seqlen, voc_size = logits.shape
        # Work on a flat (bsz * seqlen, voc_size) view; restored before returning.
        logits = logits.reshape(-1, voc_size)
    else:
        raise ValueError(f"Given `logits` has an invalid shape : {logits.shape}")

    if not temperature > 0:
        # Greedy decoding: `sample` returns the argmax in this case, so the
        # matching distribution puts all mass on that token. Top-k / top-p
        # renormalization would leave a one-hot row unchanged, so skip them.
        greedy_ids = torch.argmax(logits, dim=-1, keepdim=True)
        probs = torch.zeros_like(logits).scatter_(-1, greedy_ids, 1.0)
        return probs.reshape(bsz, seqlen, voc_size)

    probs = torch.softmax(logits / temperature, dim=-1)

    # Apply top-k first since it is the default behavior of flashinfer
    if top_k > 0:
        probs = top_k_renorm_probs(probs, top_k)
    if top_p > 0:
        probs = top_p_renorm_probs(probs, top_p)

    return probs.reshape(bsz, seqlen, voc_size)

# Sample token ids from logits with temperature scaling and optional top-k / top-p filtering
def sample(logits, top_p, top_k, temperature):
    """Draw one token id per position from `logits`.

    Args:
        logits: Tensor of shape (bsz, voc_size) or (bsz, seqlen, voc_size).
        top_p: Top-p threshold; the top-p filter is skipped when <= 0.
        top_k: Top-k count; the top-k filter is skipped when <= 0.
        temperature: Divisor applied to the logits. When it is not
            positive, the argmax is returned instead of a random draw.

    Returns:
        int64 tensor of shape (bsz, seqlen); seqlen is 1 for 2-D input.

    Raises:
        ValueError: If `logits` is neither 2-D nor 3-D.
    """
    rank = logits.ndim
    if rank not in (2, 3):
        raise ValueError(f"Given `logits` has an invalid shape : {logits.shape}")

    if rank == 2:
        batch, vocab = logits.shape
        steps = 1
        flat = logits
    else:
        batch, steps, vocab = logits.shape
        # The samplers below are fed one row per (batch, position) pair.
        flat = logits.reshape(-1, vocab)

    # Non-positive temperature means greedy decoding.
    if not temperature > 0:
        return torch.argmax(flat, dim=-1).reshape(batch, steps).long()

    scaled = flat / temperature
    use_top_k = top_k > 0
    use_top_p = top_p > 0

    if use_top_k and use_top_p:
        # Both filters: handled by one call on the logits.
        tokens = top_k_top_p_sampling_from_logits(logits=scaled, top_k=top_k, top_p=top_p)
    elif use_top_k:
        # Top-k only: this sampler takes probabilities.
        tokens = top_k_sampling_from_probs(probs=torch.softmax(scaled, dim=-1), top_k=top_k)
    elif use_top_p:
        # Top-p only: this sampler takes probabilities.
        tokens = top_p_sampling_from_probs(probs=torch.softmax(scaled, dim=-1), top_p=top_p)
    else:
        # No filtering at all.
        tokens = sampling_from_logits(logits=scaled)

    return tokens.reshape(batch, steps).long()