import numpy as np
import torch

def sampling(logit, p=0.5, t=1.0):
    """Draw one index from logits using temperature scaling and top-p sampling.

    Args:
        logit: Tensor of logits. Singleton dimensions are squeezed away and the
            softmax is taken over the last dimension; presumably the squeezed
            tensor is 1-D (one distribution) -- TODO confirm with callers.
        p: Nucleus (top-p) threshold forwarded to ``nucleus``.
        t: Softmax temperature; the logits are divided by ``t``.

    Returns:
        The sampled index, as returned by ``nucleus``.
    """
    logit = logit.squeeze()
    # bfloat16 tensors cannot be converted with ``.numpy()`` and float16
    # softmax loses precision, so compute half-precision inputs in float32.
    if logit.dtype in (torch.float16, torch.bfloat16):
        logit = logit.float()
    probs = torch.softmax(logit / t, dim=-1)
    # Detach first so the copy to host memory carries no autograd history.
    probs = probs.detach().cpu().numpy()
    return nucleus(probs, p=p)

def nucleus(probs, p):
    """Sample one index from ``probs`` with nucleus (top-p) sampling.

    The smallest set of most-likely indices whose cumulative probability
    reaches ``p`` is kept, renormalized, and sampled from with
    ``np.random.choice`` (so results follow NumPy's global RNG state).

    Args:
        probs: 1-D sequence of non-negative weights; they need not sum to 1.
            The caller's array is not modified.
        p: Cumulative-probability threshold. Values above 1.0 are treated as
            1.0 (sample from the full distribution); values <= 0 keep only the
            single most likely index.

    Returns:
        The sampled index (a NumPy integer).

    Raises:
        ValueError: If ``probs`` is empty or its total is not a positive,
            finite number.
    """
    # Work on a float copy so the caller's array is never mutated in place.
    probs = np.asarray(probs, dtype=np.float64)
    total = probs.sum()
    if probs.size == 0 or not np.isfinite(total) or total <= 0:
        raise ValueError("probs must be non-empty with a positive finite sum")
    # Exact normalization: an epsilon in the denominator would keep the
    # cumulative mass below 1 and make p == 1.0 degenerate to greedy decoding.
    probs = probs / total

    # Indices from most to least likely, and the running mass in that order.
    sorted_index = np.argsort(probs)[::-1]
    cusum_sorted_probs = np.cumsum(probs[sorted_index])

    p = min(p, 1.0)
    # First position where the running mass reaches p, kept inclusively. If
    # rounding leaves the total just short of p, keep every index instead of
    # collapsing to the single most likely one.
    cutoff = int(np.searchsorted(cusum_sorted_probs, p, side="left")) + 1
    candi_index = sorted_index[:min(cutoff, len(sorted_index))]

    candi_probs = probs[candi_index]
    candi_probs = candi_probs / candi_probs.sum()
    return np.random.choice(candi_index, size=1, p=candi_probs)[0]