import torch

def create_mixed_matrix(probs, k):
    """
    Keep the top-k probabilities of each distribution and renormalize them.

    Args:
        probs (torch.Tensor): Token probabilities; the last dim is the
            vocabulary. Any number of leading (batch) dims is supported.
        k (int): Number of highest-probability tokens to keep per
            distribution. Should be >= 1: with k <= 0 every entry is zeroed
            and the renormalization divides by zero (NaN).
    Returns:
        probs_sort (torch.Tensor): Probabilities sorted descending along the
            last dim, with entries past the first k set to 0 and the kept
            entries renormalized to sum to 1.
        probs_idx (torch.Tensor): Original vocabulary index of each entry of
            probs_sort.
        restored_probs (torch.Tensor): probs_sort moved back into the original
            vocabulary order (same layout as probs).
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    # 1-D mask over the last dim; masked_fill_ broadcasts it across all leading
    # dims. Built on probs' device so a CUDA input is not paired with a CPU mask.
    past_k = torch.arange(probs.shape[-1], device=probs.device) >= k
    probs_sort.masked_fill_(past_k, 0.0)
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # argsort of a permutation is its inverse: for each vocab id, the slot in
    # probs_sort that holds it.
    restore_idx = torch.argsort(probs_idx, dim=-1)
    restored_probs = torch.gather(probs_sort, -1, restore_idx)
    return probs_sort, probs_idx, restored_probs

def log_prob_to_prob(log_probs, temp=1):
    """
    Convert log probabilities to a temperature-scaled, normalized distribution.

    Args:
        log_probs (torch.Tensor): Log probs (n_prompts, n_drafts, vocab_size);
            normalization is over the last dim, so any leading shape works.
        temp (float): Softmax temperature; must be positive. Values > 1
            flatten the distribution, values < 1 sharpen it.
    Returns:
        torch.Tensor: Probabilities with the same shape as log_probs, summing
            to 1 over the last dim.
    """
    # No manual max shift: softmax is invariant to a per-row constant, and
    # torch.softmax already subtracts the row max internally for stability.
    # Adding a shift by hand only risks overflowing large-magnitude inputs.
    return torch.softmax(log_probs / temp, dim=-1)

def decode(tokenizer, encoding):
    """
    Decode a token sequence, truncating it at the first -1 (padding) id.

    Args:
        tokenizer (Any): Object exposing ``decode(ids)`` where ``ids`` is a
            list of ints.
        encoding (torch.Tensor): Token ids; -1 marks padding. Must be 1-D when
            it contains padding (the first pad position is read with
            ``.item()``).
    Returns:
        decoding (str)
    """
    pad_locs = (encoding == -1).nonzero()
    if len(pad_locs) > 0:
        # Drop the first pad and everything after it.
        encoding = encoding[:pad_locs[0].item()]
    return tokenizer.decode(encoding.to(torch.int32).tolist())

def print_gen(gens, logprobs, tokenizer, n_drafts, prompt_len, output_file):
    """
    Print out generations, grouped by prompt, for debugging.

    Args:
        gens (torch.Tensor): Generations to print, either
            (n_prompts, n_drafts, seq_len) or already flattened to
            (n_prompts * n_drafts, seq_len). -1 entries are treated as padding
            by ``decode``.
        logprobs (torch.Tensor): Log prob of each generation, either
            (n_prompts, n_drafts) or flattened (n_prompts * n_drafts,).
        tokenizer (any): Tokenizer passed through to ``decode``.
        n_drafts (int): Number of drafts per prompt. Used for 2-D ``gens``;
            for 3-D ``gens`` the draft count is taken from ``gens.shape[1]``.
        prompt_len (int): Number of tokens in prompt.
        output_file: Writable text stream passed to ``print(file=...)``.
    Raises:
        ValueError: If ``gens`` is neither 2-D nor 3-D.
    """
    if gens.dim() == 3:
        # The tensor's own layout takes precedence over the argument.
        n_drafts = gens.shape[1]
    elif gens.dim() != 2:
        raise ValueError(f"gens must be 2-D or 3-D, got shape {tuple(gens.shape)}")
    seq_len = gens.shape[-1]
    gens = gens.reshape(-1, seq_len)
    logprobs = logprobs.flatten()
    for i, gen in enumerate(gens):
        d = decode(tokenizer, gen)
        draft_idx = i % n_drafts
        # First draft of this prompt: print a separator and the shared prompt.
        if draft_idx == 0:
            print("---------------", file=output_file)
            prompt = decode(tokenizer, gen[:prompt_len])
            print(f"prompt: {prompt}", file=output_file)
        print(f"logprob: {logprobs[i]} {draft_idx}: {d}", file=output_file)
        
def print_probs(next_probs, tokenizer, output_file):
    """
    Print out next token options and probabilities for every draft.

    For each (prompt, draft) pair, prints the ``n_drafts + 2`` most likely
    tokens, capped at the vocabulary size. Each token is printed with its
    probability and log probability.

    Args:
        next_probs (torch.Tensor): Next token probabilities (n_prompts, n_drafts, vocab_size)
        tokenizer (any): Object exposing ``decode(ids)``; called with a
            one-element list per token id.
        output_file: Writable text stream passed to ``print(file=...)``.
    """
    print("\tReminder: At most first n_drafts from seq can be selected.", file=output_file)
    n_prompts, n_drafts, vocab_size = next_probs.shape
    # topk raises when k exceeds the dim size, so never ask for more
    # candidates than the vocabulary has.
    n_show = min(n_drafts + 2, vocab_size)
    for p_idx in range(n_prompts):
        print(f"\tPrompt {p_idx}:", file=output_file)
        for d_idx in range(n_drafts):
            next_token_probs, next_token_idx = next_probs[p_idx, d_idx].topk(n_show, dim=-1)
            print(f"\t\tTokens: {[tokenizer.decode([i.item()]) for i in next_token_idx]}", file=output_file)
            print(f"\t\tLog Probs: {torch.log(next_token_probs)}", file=output_file)
            print(f"\t\tProbs: {next_token_probs}", file=output_file)