
import torch
from torch.nn import functional as F
from torch.distributions import Categorical

def decode_with_top_k(logits, k=50, temperature=1.0):
    """Sample one token per position from the ``k`` highest-scoring vocabulary entries.

    Args:
        logits: Unnormalized scores with the vocabulary on the last axis,
            e.g. ``[batch_size, seq_len, vocab_size]``.
        k: Number of top-scoring candidates to sample from. Values larger
            than the vocabulary size are clamped to it.
        temperature: Divisor applied to ``logits`` before filtering; expected
            to be positive.

    Returns:
        ``(selected_token_ids, selected_confidences)``, both shaped like
        ``logits`` without its last axis. ``selected_confidences`` is the
        probability of the sampled token after renormalizing over the top-k
        candidates, returned in ``logits.dtype``.

    Raises:
        ValueError: If ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")
    logits = logits / temperature

    # torch.topk raises if k exceeds the size of the vocabulary axis.
    k = min(k, logits.size(-1))
    top_k_values, top_k_indices = torch.topk(logits, k=k, dim=-1)

    # Softmax and sample in fp32 so half-precision inputs do not lose
    # probability mass; softmax output is already normalized.
    probs = F.softmax(top_k_values.float(), dim=-1)
    sampler = Categorical(probs=probs, validate_args=False)
    next_token = sampler.sample().unsqueeze(-1)  # index into the top-k axis

    selected_token_ids = top_k_indices.gather(-1, next_token).squeeze(-1)
    selected_confidences = (
        probs.gather(-1, next_token).squeeze(-1).to(dtype=logits.dtype)
    )
    return selected_token_ids, selected_confidences

def decode_with_greedy(logits, temperature=1.0):
    """Pick the highest-scoring token at every position.

    Args:
        logits: Unnormalized scores with the vocabulary on the last axis,
            e.g. ``[batch_size, seq_len, vocab_size]``.
        temperature: Divisor applied to ``logits``. It only affects the
            returned confidences, not which token is chosen (for positive values).

    Returns:
        ``(tokens, confidences)``. ``confidences`` is the largest softmax
        probability at each position of the temperature-scaled logits.
    """
    scaled = logits / temperature
    distribution = F.softmax(scaled, dim=-1)
    best_prob, _ = torch.max(distribution, dim=-1)
    best_token = torch.argmax(scaled, dim=-1)
    return best_token, best_prob

def decode_with_sample(logits, temperature=1.0):
    """Sample one token per position from the full softmax distribution.

    Args:
        logits: Unnormalized scores with the vocabulary on the last axis,
            e.g. ``[batch_size, seq_len, vocab_size]``; any number of leading
            dimensions is accepted.
        temperature: Divisor applied to ``logits``; expected to be positive.

    Returns:
        ``(tokens, confidences)``, both shaped like ``logits`` without its
        last axis. ``confidences`` is the probability of the sampled token,
        consistent with the other samplers in this module.
    """
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)

    # torch.multinomial only accepts 1-D or 2-D input, so flatten all
    # leading dimensions into one batch axis and restore them afterwards.
    vocab_size = probs.size(-1)
    flat_tokens = torch.multinomial(probs.reshape(-1, vocab_size), num_samples=1)
    tokens = flat_tokens.reshape(probs.shape[:-1])

    confidences = probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    return tokens, confidences

def decode_with_top_p(logits, temperature=1.0, top_k=0, top_p=0.9):
    """Sample one token per position using nucleus (top-p) and optional top-k filtering.

    Args:
        logits: Unnormalized scores with the vocabulary on the last axis;
            any number of leading dimensions is accepted.
        temperature: Divisor applied to ``logits``; expected to be positive.
        top_k: If > 0, additionally restrict sampling to the ``top_k`` most
            likely tokens. ``0`` disables this filter.
        top_p: Cumulative-probability threshold for the nucleus filter.

    Returns:
        ``(token_index, confidence)``, both shaped like ``logits`` without its
        last axis. ``confidence`` is the sampled token's probability after
        renormalizing over the kept candidates.
    """
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)

    sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Top-p: keep token i when the mass of all tokens ranked before it is
    # <= top_p. This includes the token that crosses the threshold, and the
    # most likely token is always kept.
    sorted_mask = torch.ones_like(sorted_probs, dtype=torch.bool)
    sorted_mask[..., 1:] = cumulative_probs[..., :-1] <= top_p

    # Top-k: tokens are already sorted, so drop everything past rank top_k.
    if top_k > 0:
        sorted_mask[..., top_k:] = False

    filtered_probs = sorted_probs * sorted_mask
    filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)

    # torch.multinomial only accepts 1-D or 2-D input; flatten leading dims.
    vocab_size = filtered_probs.size(-1)
    flat_choice = torch.multinomial(
        filtered_probs.reshape(-1, vocab_size), num_samples=1
    )
    choice = flat_choice.reshape(tuple(filtered_probs.shape[:-1]) + (1,))

    token_index = sorted_indices.gather(-1, choice)
    confidence = filtered_probs.gather(-1, choice)
    return token_index.squeeze(-1), confidence.squeeze(-1)

def logits_processor_decode(logits, logits_processor):
    """Apply a chain of logits processors, then sample one token per position.

    Args:
        logits: Unnormalized scores with the vocabulary on the last axis;
            any number of leading dimensions is accepted. Not modified.
        logits_processor: Iterable of single-argument callables, each mapping
            a logits tensor to a new logits tensor, applied in order. ``None``
            means no processing.

    Returns:
        ``(next_tokens, current_confidences)``, both shaped like ``logits``
        without its last axis. ``current_confidences`` is the sampled token's
        probability under the processed distribution.
    """
    # Clone so processors that modify their input in place cannot alter
    # the caller's tensor.
    processed_logits = logits.clone()
    if logits_processor is not None:
        for processor in logits_processor:
            processed_logits = processor(processed_logits)

    probs = F.softmax(processed_logits, dim=-1)

    # torch.multinomial only accepts 1-D or 2-D input; flatten leading dims.
    vocab_size = probs.size(-1)
    flat_tokens = torch.multinomial(probs.reshape(-1, vocab_size), num_samples=1)
    next_tokens = flat_tokens.reshape(probs.shape[:-1])

    current_confidences = probs.gather(-1, next_tokens.unsqueeze(-1)).squeeze(-1)
    return next_tokens, current_confidences