import torch
import numpy as np
import torch.nn.functional as F
import os
import math


def top_p_logits(logits, top_p=None):
    """Nucleus (top-p) filter ``logits`` along the last dimension.

    Keeps the smallest prefix of highest-probability entries whose cumulative
    softmax probability exceeds ``top_p`` (the top entry is always kept) and
    sets every other entry to the most negative finite value of the dtype.

    Args:
        logits: Tensor of unnormalised scores; filtering is over the last dim.
        top_p: Cumulative-probability threshold. ``None`` disables filtering
            and returns ``logits`` unchanged (previously this raised a
            ``TypeError`` from comparing a tensor with ``None``).

    Returns:
        A tensor shaped like ``logits`` with the filtered entries masked out.
    """
    if top_p is None:
        return logits
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    # Drop entries whose cumulative mass exceeds top_p, shifted right by one so
    # the entry that first crosses the threshold is still kept.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    # Map the removal flags from sorted order back to the original positions.
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)

def sample_tokens(logits, top_p=None, neg_entropy=False):
    """Turn ``logits`` into probabilities and a per-position confidence score.

    Args:
        logits: Unnormalised scores; the distribution is over the last dim.
        top_p: Optional nucleus threshold, applied via ``top_p_logits`` when it
            is given and below 1.
        neg_entropy: If True, the confidence is the negative entropy
            ``sum(p * log(p + 1e-10))`` (<= 0; closer to 0 means more peaked).
            Otherwise it is the largest probability at each position.

    Returns:
        ``(confidence, probs)``: ``probs`` is the softmax of the (possibly
        filtered) logits and ``confidence`` has the last dim reduced away.
    """
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    probs = torch.softmax(logits, dim=-1)

    if neg_entropy:
        epsilon = 1e-10  # keeps log() finite for zero-probability entries
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)
    else:
        # Without this branch `confidence` was unbound and the call raised
        # UnboundLocalError; fall back to the max-probability confidence.
        confidence = probs.max(dim=-1).values

    return confidence, probs

def add_gumbel_noise(logits, temperature):
    """Perturb ``logits`` so that taking ``argmax`` draws a tempered sample.

    Computes ``exp(logits) / (-log u) ** temperature`` with ``u ~ U[0, 1)``, in
    float64. For ``temperature > 0`` its argmax is distributed like
    ``argmax(logits / temperature + Gumbel noise)`` (Gumbel-max trick).
    Per arXiv:2409.02908, low-precision Gumbel-max sampling for masked
    diffusion models improves perplexity but lowers generation quality, which
    is why float64 is used.

    A ``temperature`` of 0 returns ``logits`` untouched (greedy decoding).
    """
    if temperature == 0:
        return logits
    scores = logits.to(torch.float64)
    uniform = torch.rand_like(scores, dtype=torch.float64)
    denominator = (-uniform.log()).pow(temperature)
    return scores.exp() / denominator


def get_num_transfer_tokens(mask_index, steps, hot_start=0):
    """Split each row's [MASK] count into a per-step unmasking schedule.

    Args:
        mask_index: Boolean (batch, length) tensor; True marks a masked slot.
        steps: Number of sampling steps (columns of the schedule).
        hot_start: Number of leading steps that receive no tokens.

    Returns:
        int64 tensor of shape (batch, steps). Columns ``hot_start:`` each hold
        ``count // (steps - hot_start)``, and the ``count % (steps - hot_start)``
        leftover tokens add one each to the final steps. When
        ``steps - hot_start <= 0`` the schedule is all zeros.
    """
    counts = mask_index.sum(dim=1, keepdim=True)
    schedule = torch.zeros(counts.size(0), steps, device=mask_index.device, dtype=torch.int64)
    active = steps - hot_start
    if active <= 0:
        return schedule

    schedule[:, hot_start:] = counts // active
    leftovers = (counts % active).flatten().tolist()
    for row, extra in enumerate(leftovers):
        if extra > 0:
            # Spread the leftover tokens over the last `extra` steps.
            schedule[row, steps - extra:] += 1
    return schedule

def compute_KL(probs_pre, probs):
    """Return the row-averaged KL divergence KL(probs_pre || probs).

    Both inputs are clamped to at least 1e-8 so the logarithm never sees 0.
    The pointwise terms are summed over dim 1 (the distribution axis for 2-D
    inputs, presumably (rows, vocab)) and then averaged over rows.

    Returns:
        A 0-d tensor.
    """
    floor = 1e-8
    p = probs_pre.clamp(min=floor)
    q = probs.clamp(min=floor)
    per_row = (p * (p / q).log()).sum(dim=1)
    return per_row.mean()


@torch.no_grad()
def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=126336, eos_token_id=126081):
    '''
    Decode a masked response after `prompt`, block by block, with an optional
    soft-embedding warm-up phase.

    Until decoding is switched on (`switch2decode == 0`) each step only runs
    the model and records its top-p-filtered distribution (`probs_full`). On
    the next forward pass, masked positions are fed a blend of the [MASK]
    embedding and the expected token embedding under that distribution.
    Decoding is switched on once |KL(prev || cur) - KL(prev-prev || cur)|
    drops below DIFFTHRESHOLD, or once the step index exceeds COLD_START.

    Args:
        model: Mask predictor. The soft-embedding path accesses
            `model.module.model.transformer.wte` (presumably a wrapped model,
            e.g. DDP; confirm).
        prompt: Tensor of prompt token ids, shape (batch, L).
        steps: Total sampling steps, divided evenly across blocks.
        gen_length: Generated answer length.
        block_length: Block length; must divide gen_length. If less than
            gen_length, blocks are decoded left to right (semi-autoregressive
            remasking).
        temperature: Gumbel-max sampling temperature; 0 means greedy argmax.
        cfg_scale: Unsupervised classifier-free guidance scale. When > 0 the
            unconditional pass masks the prompt, and the soft-embedding path
            is not used.
        remasking: 'low_confidence' (commit the highest-probability
            predictions) or 'random' (commit positions in random order).
        mask_id: Token id of [MASK] (126336 by default).
        eos_token_id: Token id treated as end of sequence. Positions from the
            first occurrence onwards are excluded from decoding. None disables
            EOS handling.

    Environment variables (all required; a missing one raises KeyError):
        BASE_RATE: weight on the [MASK] embedding when blending. 1 disables
            the soft-embedding path.
        RATE_FLEX: scale of the normalised-entropy term added to BASE_RATE.
        COLD_START: step index after which decoding is forced on.
        DEAN_TOPP: top-p used for the recorded distribution. 0 disables it,
            and with it the soft-embedding path and the KL switches.
        DIFFTHRESHOLD: KL-difference threshold that switches decoding on.
        KLTHRESHOLD: when KL(prev || cur) drops below this, `switch2stop` is
            set. It only has an effect when eos_token_id is None.

    Returns:
        Tensor of token ids, shape (batch, L + gen_length).
    '''
    batch_size = prompt.shape[0]
    seq_length = prompt.shape[1] + gen_length
    # One row per prompt so that x matches decode_range's (batch, seq) shape.
    x = torch.full((batch_size, seq_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    # True wherever x does not start as [MASK], i.e. the prompt positions.
    prompt_index = (x != mask_id)
    # Positions still eligible for decoding; cleared from the first EOS onwards.
    decode_range = torch.ones(batch_size, seq_length, device=x.device, dtype=torch.bool)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks  # sampling steps per block

    base_rate = float(os.environ['BASE_RATE'])
    rate_flex = float(os.environ['RATE_FLEX'])
    hot_start = int(os.environ['COLD_START'])
    dean_topp = float(os.environ['DEAN_TOPP'])
    diffthreshold = float(os.environ['DIFFTHRESHOLD'])
    klthreshold = float(os.environ['KLTHRESHOLD'])

    probs_full = None      # top-p-filtered probabilities from the latest step
    full_entropy = None    # per-position blend offset derived from the entropy
    all_embeddings = None  # embedding of every vocab id, built lazily once
    probs_pre = None       # probs_full from the previous step
    probs_ppre = None      # probs_full from two steps ago
    cur_KL = -1
    pre_KL = -1

    i = 0
    switch2decode = 0
    switch2stop = 0

    for num_block in range(num_blocks):
        block_start = prompt.shape[1] + num_block * block_length
        block_end = block_start + block_length
        block_mask_index = (x[:, block_start:block_end] == mask_id)
        current_steps = steps
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, current_steps, 0)
        i = 0
        # NOTE(review): `i` is never incremented inside this loop. As a result:
        # every step reads num_transfer_tokens[:, 0]; `i > hot_start` can only
        # fire for a negative COLD_START; and the loop is left only through
        # `return x`. If column 0 of the schedule is 0 (fewer masks than
        # steps), or once this block is fully decoded, later steps may make
        # no progress. Confirm where the increment was intended.
        while i < current_steps:

            if eos_token_id is not None:
                eos_positions = (x == eos_token_id)
                for b in range(batch_size):
                    eos_indices = torch.where(eos_positions[b])[0]
                    if len(eos_indices) > 0:
                        first_eos_idx = eos_indices[0].item()
                        decode_range[b, first_eos_idx:] = False
            elif switch2stop == 1:
                # NOTE(review): `i` is a step index but is used here as a
                # sequence position; confirm this is intended.
                for b in range(batch_size):
                    decode_range[b, i+1:] = False

            mask_index = (x == mask_id) & decode_range

            # Rows whose first EOS has no [MASK] before it are finished.
            batch_complete = []
            for b in range(batch_size):
                if eos_token_id is not None and (x[b] == eos_token_id).any():
                    first_eos_idx = torch.where(x[b] == eos_token_id)[0][0].item()
                    if not (x[b, :first_eos_idx] == mask_id).any():
                        batch_complete.append(b)

            if batch_complete:
                # Fill every [MASK] left in finished rows (including after EOS)
                # with a greedy prediction from one extra forward pass.
                remaining_mask = torch.zeros_like(mask_index)
                for b in batch_complete:
                    remaining_mask[b] = (x[b] == mask_id)

                if remaining_mask.any():
                    logits = model(x).logits
                    remaining_logits = logits[remaining_mask]
                    remaining_tokens = torch.argmax(remaining_logits, dim=-1)
                    x[remaining_mask] = remaining_tokens

                    for b in batch_complete:
                        mask_index[b] = False

            if not mask_index.any():
                return x

            rate = base_rate
            if cfg_scale > 0.:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                if base_rate != 1 and probs_full is not None:
                    # Soft input: at masked positions, blend the [MASK] token
                    # embedding with the expected embedding under the previous
                    # step's probs_full.
                    inputs_embeds = model.module.model.transformer.wte(x)
                    vocal_size = probs_full.shape[-1]
                    if all_embeddings is None:
                        all_embeddings = model.module.model.transformer.wte(
                            torch.arange(vocal_size, device=probs_full.device))
                    mix_embed = torch.matmul(probs_full, all_embeddings)
                    inputs_embeds[mask_index] = \
                        (rate+full_entropy[mask_index])*inputs_embeds[mask_index] + \
                        (1-rate-full_entropy[mask_index])*mix_embed[mask_index]
                    logits = model(inputs_embeds=inputs_embeds).logits
                else:
                    logits = model(x).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)

            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
                if dean_topp != 0:
                    full_entropy, probs_full = sample_tokens(logits, top_p=dean_topp, neg_entropy=True)
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            if full_entropy is not None:
                # sample_tokens returns sum(p log p) <= 0. Negating and dividing
                # by log(vocab) gives the entropy normalised to roughly [0, 1],
                # which is then scaled by rate_flex.
                if rate_flex == 0:
                    full_entropy = 0 * full_entropy
                else:
                    full_entropy = -full_entropy / (math.log(probs_full.size(-1))) * rate_flex
                full_entropy = full_entropy.unsqueeze(-1)

            if probs_ppre is not None:
                # Compare the current distribution at still-masked positions
                # against the previous two steps.
                current = probs_full[mask_index]
                cur_KL = compute_KL(probs_pre[mask_index], current)
                if switch2decode == 0:
                    pre_KL = compute_KL(probs_ppre[mask_index], current)
                    if abs(cur_KL - pre_KL) < diffthreshold or i > hot_start:
                        switch2decode = 1
                if cur_KL < klthreshold:
                    switch2stop = 1
            probs_ppre = probs_pre
            probs_pre = probs_full

            # Warm-up phase: run the model again without committing tokens.
            if switch2decode == 0:
                continue

            # Never commit tokens beyond the current block.
            x0_p[:, block_end:] = -np.inf

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, torch.tensor(-np.inf, device=x0.device))

            # Commit the num_transfer_tokens[j, i] highest-scoring predictions per row.
            for j in range(confidence.shape[0]):
                num_tokens = num_transfer_tokens[j, i].item()
                if num_tokens > 0:
                    _, select_indices = torch.topk(confidence[j], k=num_tokens)
                    x[j, select_indices] = x0[j, select_indices]
    return x
