from collections import namedtuple
import torch

# Immutable record with three positional fields. Field meanings are not
# established in this file; presumably this is what a drafting step returns
# (token sequences, the draft model's probabilities, and its KV cache) --
# confirm against the producer.
DraftOutput = namedtuple(
    'DraftOutput',
    ('sequences', 'draft_probs', 'draft_past_key_values'),
)
# Immutable record with four positional fields. Presumably this is what a
# verification step returns (sequences, both models' KV caches, and how many
# drafted tokens were accepted) -- confirm against the producer.
VerifyOutput = namedtuple(
    'VerifyOutput',
    (
        'sequences',
        'target_past_key_values',
        'draft_past_key_values',
        'accept_count',
    ),
)
# Immutable record with seven positional fields: the generated sequences plus
# run statistics. NOTE: the namedtuple's type name (used in repr) is
# 'SpeculativeGeneratorOutput', which differs from the module-level name
# `GeneratorOutput` it is bound to.
GeneratorOutput = namedtuple(
    'SpeculativeGeneratorOutput',
    (
        'sequences',
        'acceptance_rate',
        'token_rate',
        'avg_generation_time',
        'avg_verification_time',
        'num_invocations',
        'total_time',
    ),
)

def log(t, eps=1e-20):
    """Return the elementwise natural log of ``t`` with inputs floored at ``eps``.

    Values of ``t`` below ``eps`` are raised to ``eps`` before the logarithm is
    taken, so zeros and negatives give ``log(eps)`` instead of ``-inf``/``nan``.
    """
    floored = torch.clamp(t, min=eps)
    return floored.log()

def gumbel_noise(noise):
    """Apply the transform ``-log(-log(noise))`` elementwise.

    When ``noise`` holds uniform samples on (0, 1) this is the inverse-CDF
    transform that yields standard Gumbel samples. Both logarithms go through
    the module's clamped ``log`` helper.
    """
    neg_log_noise = -log(noise)
    return -log(neg_log_noise)

def gumbel_sample(logits, noise=None, dim=-1):
    """Draw indices from ``logits`` by adding transformed noise and taking the argmax.

    Args:
        logits: Tensor of unnormalised scores.
        noise: Optional tensor that is passed through ``gumbel_noise`` and
            added to ``logits``. When omitted, fresh uniform noise on [0, 1)
            with the same shape, dtype and device as ``logits`` is drawn.
        dim: Dimension along which the argmax is taken.

    Returns:
        Tensor of indices of the maximum perturbed score along ``dim``.
    """
    # Identity check: `noise == None` would route through Tensor.__eq__ when a
    # tensor is supplied instead of simply testing for the "not given" case.
    if noise is None:
        noise = torch.zeros_like(logits).uniform_(0, 1)
    perturbed = logits + gumbel_noise(noise)
    return perturbed.argmax(dim=dim)

class LogitsProcessor:
    """Mask logits to the top-k entries and scale them by a temperature."""

    def __init__(self, temperature, top_p, top_k):
        """Store the sampling settings.

        Args:
            temperature: Divisor applied to the logits in ``__call__``.
            top_p: Stored on the instance only; ``__call__`` does not apply
                any top-p filtering.
            top_k: Number of largest logits kept along the last dimension.
                ``None`` disables the top-k mask.
        """
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k

    def __call__(self, logits):
        """Return ``logits`` with non-top-k entries set to ``-inf``, divided by temperature.

        Args:
            logits: Tensor of scores; filtering is done along the last dimension.

        Returns:
            A new tensor shaped like ``logits``. Entries outside the ``top_k``
            largest along the last dimension are ``-inf``; the kept entries
            are ``logits / temperature``. If ``top_k`` is ``None`` or is at
            least the size of the last dimension, nothing is masked.
        """
        # A 0-dim tensor has no last dimension to query; treat it as size 1.
        last_dim_size = logits.size(-1) if logits.dim() > 0 else 1

        if self.top_k is None or self.top_k >= last_dim_size:
            # torch.topk raises when k exceeds the dimension size, and keeping
            # every entry is equivalent to applying no mask at all.
            filtered = logits
        else:
            val, ind = torch.topk(logits, self.top_k, sorted=False, dim=-1)
            # Start fully masked, then copy the surviving logits back in place.
            filtered = torch.full_like(logits, float('-inf'))
            filtered.scatter_(-1, ind, val)

        # Out-of-place division, so the caller's tensor is never modified.
        return filtered / self.temperature
