import torch

# with replacement


def wr_gen(logits_q, temperature, num_candidates):
    """Draw candidate tokens independently, i.e. with replacement.

    Args:
        logits_q: logits of shape (batch_size, vocab_size).
        temperature: scalar sampling temperature; 0 means greedy selection.
        num_candidates: scalar number of candidates produced per row.

    Returns:
        Token indices of shape (batch_size, num_candidates). For
        ``temperature == 0`` every column holds the row's argmax token and
        the result is an expanded view rather than a materialized copy.
    """
    greedy = temperature == 0
    if not greedy:
        # Temperature-scaled categorical distribution; duplicates are allowed.
        probs = torch.softmax(logits_q / temperature, dim=-1)
        return torch.multinomial(probs, num_candidates, replacement=True)

    # Greedy: broadcast the single best token across all candidate slots.
    best_token = torch.argmax(logits_q, dim=-1, keepdim=True)
    return best_token.expand(-1, num_candidates)


# without replacement


def wor_gen(logits_q, temperature, num_candidates):
    """Draw distinct candidate tokens, i.e. without replacement.

    Args:
        logits_q: logits whose last dimension indexes the vocabulary.
        temperature: scalar sampling temperature; 0 means greedy selection.
        num_candidates: number of candidates produced per row.

    Returns:
        Token indices with ``num_candidates`` entries along the last
        dimension. For ``temperature == 0`` these are the top-k tokens in
        descending logit order; otherwise they are sampled without
        replacement from the temperature-scaled softmax.
    """
    if temperature != 0:
        probs = torch.softmax(logits_q / temperature, dim=-1)
        return torch.multinomial(probs, num_candidates, replacement=False)

    # Greedy: the k highest-scoring tokens are the deterministic choice.
    _, top_tokens = torch.topk(logits_q, num_candidates, dim=-1)
    return top_tokens


# GCSpS


def GCSpS_gen(logits_q, temperature, num_candidates):
    """Return the top ``num_candidates - 1`` tokens plus one sampled token.

    Args:
        logits_q: logits whose last dimension indexes the vocabulary.
        temperature: scalar sampling temperature; 0 means greedy selection.
        num_candidates: number of candidates produced per row.

    Returns:
        Token indices with ``num_candidates`` entries along the last
        dimension. With a single candidate this defers to ``wr_gen``. For
        ``temperature == 0`` the result is the plain top-k. Otherwise the
        first ``num_candidates - 1`` entries are the highest-logit tokens and
        the last entry is sampled from the temperature-scaled softmax over
        the remaining tokens.
    """
    if num_candidates == 1:
        return wr_gen(logits_q, temperature, num_candidates)
    if temperature == 0:
        return torch.topk(logits_q, num_candidates, dim=-1).indices

    num_greedy = num_candidates - 1
    greedy_tokens = torch.topk(logits_q, num_greedy, dim=-1).indices

    # Mask the already-chosen tokens with -inf so the softmax assigns them
    # zero probability; out-of-place scatter leaves the caller's logits intact.
    masked_logits = logits_q.scatter(-1, greedy_tokens, -torch.inf)
    tail_probs = torch.softmax(masked_logits / temperature, dim=-1)
    sampled_token = torch.multinomial(tail_probs, 1, replacement=True)

    return torch.cat((greedy_tokens, sampled_token), dim=-1)
