import torch
from flashinfer.logits_processor import LogitsPipe, Temperature, Softmax, TopP, TopK
from flashinfer.sampling import (
    top_k_renorm_probs,
    top_p_renorm_probs,
    top_k_sampling_from_probs,
    top_p_sampling_from_probs,
    top_k_top_p_sampling_from_logits,
    sampling_from_logits,
    softmax,
)


def get_sampling_probs(logits, top_p, top_k, temperature):
    """Return the per-position token distribution after temperature/top-p/top-k.

    Args:
        logits: Tensor of shape ``(bsz, voc_size)`` or ``(bsz, seqlen, voc_size)``.
        top_p: Nucleus threshold; top-p renormalization is applied only when
            ``top_p > 0``.
        top_k: Top-k cutoff; top-k renormalization is applied only when
            ``top_k > 0``.
        temperature: Softmax temperature. A scalar ``0`` selects greedy
            decoding and yields a one-hot distribution on the argmax token.

    Returns:
        Tensor of shape ``(bsz, seqlen, voc_size)``; ``seqlen`` is 1 for 2-D input.

    Raises:
        ValueError: If ``logits`` is neither 2-D nor 3-D.
    """
    if logits.dim() == 2:
        bsz, voc_size = logits.shape
        seqlen = 1
    elif logits.dim() == 3:
        bsz, seqlen, voc_size = logits.shape
        # The kernels below operate on a flat (rows, vocab) layout.
        logits = logits.reshape(-1, voc_size)
    else:
        raise ValueError(f"Given `logits` has an invalid shape : {logits.shape}")

    # A zero temperature cannot be expressed as a softmax scale (division by
    # zero). Treat it as greedy decoding instead: all mass on the argmax token.
    # Top-p / top-k filtering leaves a one-hot distribution unchanged, so it
    # is skipped. Tensor temperatures are passed through to the softmax as-is.
    if not isinstance(temperature, torch.Tensor) and temperature == 0:
        greedy_ids = torch.argmax(logits, dim=-1, keepdim=True)
        probs = torch.zeros(logits.shape, dtype=torch.float32, device=logits.device)
        probs.scatter_(-1, greedy_ids, 1.0)
        return probs.reshape(bsz, seqlen, voc_size)

    probs = softmax(logits, temperature=temperature)
    if top_p > 0:
        probs = top_p_renorm_probs(probs, top_p)
    if top_k > 0:
        probs = top_k_renorm_probs(probs, top_k)
    return probs.reshape(bsz, seqlen, voc_size)


def get_sampling_probs_v1(logits, top_p, top_k, temperature):
    """Return the per-position token distribution, computed via ``LogitsPipe``.

    Follows the same conventions as ``get_sampling_probs``: a filter is only
    applied when its parameter is ``> 0``, and a scalar ``temperature == 0``
    means greedy decoding.

    Args:
        logits: Tensor of shape ``(bsz, voc_size)`` or ``(bsz, seqlen, voc_size)``.
        top_p: Nucleus threshold; the ``TopP`` stage is added only when
            ``top_p > 0``.
        top_k: Top-k cutoff; the ``TopK`` stage is added only when ``top_k > 0``.
        temperature: Softmax temperature. A scalar ``0`` yields a one-hot
            distribution on the argmax token.

    Returns:
        Tensor of shape ``(bsz, seqlen, voc_size)``; ``seqlen`` is 1 for 2-D input.

    Raises:
        ValueError: If ``logits`` is neither 2-D nor 3-D.
    """
    if logits.dim() == 2:
        bsz, voc_size = logits.shape
        seqlen = 1
    elif logits.dim() == 3:
        bsz, seqlen, voc_size = logits.shape
        # The pipeline is fed a flat (rows, vocab) layout.
        logits = logits.reshape(-1, voc_size)
    else:
        raise ValueError(f"Given `logits` has an invalid shape : {logits.shape}")

    # A zero temperature cannot be applied as a scale (division by zero).
    # Treat it as greedy decoding: all probability mass on the argmax token.
    # Tensor temperatures are handed to the pipeline unchanged.
    if not isinstance(temperature, torch.Tensor) and temperature == 0:
        greedy_ids = torch.argmax(logits, dim=-1, keepdim=True)
        probs = torch.zeros(logits.shape, dtype=torch.float32, device=logits.device)
        probs.scatter_(-1, greedy_ids, 1.0)
        return probs.reshape(bsz, seqlen, voc_size)

    # Build only the stages that are enabled, so a value of 0 means "filter
    # disabled" rather than being forwarded to the filter.
    stages = [Temperature(), Softmax()]
    pipe_args = {"temperature": temperature}
    if top_p > 0:
        stages.append(TopP())
        pipe_args["top_p"] = top_p
    if top_k > 0:
        stages.append(TopK())
        pipe_args["top_k"] = top_k

    logits_pipe = LogitsPipe(stages)
    probs = logits_pipe(logits, **pipe_args)
    return probs.reshape(bsz, seqlen, voc_size)
    

# Sample token ids from logits with optional top-k, top-p, and temperature
def sample(logits, top_p, top_k, temperature):
    """Draw one token id for every position in ``logits``.

    Args:
        logits: Tensor of shape ``(bsz, voc_size)`` or ``(bsz, seqlen, voc_size)``.
        top_p: Nucleus threshold; top-p filtering is used only when ``top_p > 0``.
        top_k: Top-k cutoff; top-k filtering is used only when ``top_k > 0``.
        temperature: Sampling temperature; ``0`` selects the argmax token
            instead of sampling.

    Returns:
        Int64 tensor of shape ``(bsz, seqlen)``; ``seqlen`` is 1 for 2-D input.

    Raises:
        ValueError: If ``logits`` is neither 2-D nor 3-D.
    """
    rank = len(logits.shape)
    if rank not in (2, 3):
        raise ValueError(f"Given `logits` has an invalid shape : {logits.shape}")

    if rank == 3:
        bsz, seqlen, voc_size = logits.shape
        # Collapse batch and sequence into rows for the sampling kernels.
        logits = logits.reshape(-1, voc_size)
    else:
        bsz, seqlen = logits.shape[0], 1

    # Greedy decoding: no randomness, no filtering.
    if temperature == 0:
        greedy_ids = torch.argmax(logits, dim=-1)
        return greedy_ids.reshape(bsz, seqlen).long()

    if top_k > 0 and top_p > 0:
        # Joint top-k + top-p, sampling straight from the scaled logits.
        drawn = top_k_top_p_sampling_from_logits(
            logits=logits / temperature, top_k=top_k, top_p=top_p
        )
    elif top_k > 0:
        # Top-k only.
        scaled_probs = softmax(logits, temperature=temperature)
        drawn = top_k_sampling_from_probs(probs=scaled_probs, top_k=top_k)
    elif top_p > 0:
        # Top-p only.
        scaled_probs = softmax(logits, temperature=temperature)
        drawn = top_p_sampling_from_probs(probs=scaled_probs, top_p=top_p)
    else:
        # Neither filter enabled: sample from the scaled logits.
        drawn = sampling_from_logits(logits=logits / temperature)

    return drawn.reshape(bsz, seqlen).long()