#! /usr/bin/env python3
# coding=utf-8


import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


def top_p_decoding(token_logits, top_p, min_token_keep):
    """
    Sample one token id with nucleus (top-p) filtering.

    Tokens are ranked by descending logit; the smallest prefix whose
    cumulative softmax probability exceeds ``top_p`` is kept (the token that
    crosses the threshold is included), every other token is masked to -inf,
    and one id is drawn from the renormalised distribution.

    The input tensor is NOT modified; filtering happens on a copy.

    :param token_logits: <tensor> unnormalised scores, vocabulary on the last
        dimension, typically of shape (1, vocab_size)
    :param top_p: <float> cumulative probability threshold p
    :param min_token_keep: <int> min num of tokens kept regardless of p
    :return: <tensor> sampled token id(s) as returned by ``torch.multinomial``
        with ``num_samples=1``, i.e. (1, 1) for a (1, vocab_size) input
    """
    sorted_logits, sorted_indices = torch.sort(token_logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # True where the running probability mass has already gone past top_p.
    exceeds_threshold = cumulative_probs > top_p

    # Shift the flags one slot to the right so the token that crosses the
    # threshold is still kept; slot 0 stays False, which guarantees that at
    # least the most likely token survives.
    sorted_remove = torch.zeros_like(exceeds_threshold)
    sorted_remove[..., 1:] = exceeds_threshold[..., :-1]

    # Never drop any of the min_token_keep most likely tokens.
    if min_token_keep > 1:
        sorted_remove[..., :min_token_keep] = False

    # Map the mask from sorted order back to vocabulary order. sorted_indices
    # is a permutation of the last dim, so every position is written once.
    remove_mask = torch.zeros_like(sorted_remove).scatter(-1, sorted_indices, sorted_remove)

    # Out-of-place fill: the caller's logits must stay intact.
    filtered_logits = token_logits.masked_fill(remove_mask, -float("Inf"))

    # pick the token
    token_probs_norm = F.softmax(filtered_logits, dim=-1)
    token_id = torch.multinomial(token_probs_norm, num_samples=1)
    return token_id


def top_k_decoding(token_logits, top_k, min_token_keep):
    """
    Sample one token id from the k highest-scoring tokens.

    Every token whose logit is strictly below the k-th largest logit is masked
    to -inf, and one id is drawn from the renormalised distribution. Tokens
    tied with the k-th largest logit are all kept.

    The input tensor is NOT modified; filtering happens on a copy.

    :param token_logits: <tensor> unnormalised scores, vocabulary on the last
        dimension, typically of shape (1, vocab_size)
    :param top_k: <int> number of highest-scoring tokens to keep
    :param min_token_keep: <int> min num of tokens kept
    :return: <tensor> sampled token id(s) as returned by ``torch.multinomial``
        with ``num_samples=1``, i.e. (1, 1) for a (1, vocab_size) input
    """
    # top k filtering with respect to min_token_keep; clamp to the vocabulary
    # size and keep at least one token so the distribution is never empty.
    vocab_size = token_logits.size(-1)
    safe_top_k = max(1, min(max(top_k, min_token_keep), vocab_size))

    # Smallest logit that is still inside the top k, broadcast over the vocab.
    kth_largest = torch.topk(token_logits, safe_top_k, dim=-1)[0][..., -1, None]

    # Strict comparison: the k-th token itself must survive. Using <= here
    # would keep only k-1 tokens and leave nothing to sample when k == 1.
    remove_mask = token_logits < kth_largest

    # Out-of-place fill: the caller's logits must stay intact.
    filtered_logits = token_logits.masked_fill(remove_mask, -float("Inf"))

    # pick the token
    token_probs_norm = F.softmax(filtered_logits, dim=-1)
    token_id = torch.multinomial(token_probs_norm, num_samples=1)
    return token_id


def greedy_decoding(token_logits):
    """
    Pick the most probable token (no sampling involved).

    :param token_logits: <tensor> unnormalised scores with the vocabulary on
        the last dimension
    :return: <tensor> index of the highest-probability entry along the last
        dimension, with a new leading axis of size 1 prepended
    """
    distribution = token_logits.softmax(dim=-1)
    best_index = distribution.argmax(dim=-1)
    # Prepend a leading axis: a (1, vocab) input yields a (1, 1) result.
    return best_index[None, :]


