import torch
import numpy as np
import torch.nn.functional as F


def add_gumbel_noise(logits, temperature):
    '''
    Perturb logits so that an argmax over the result draws a categorical sample.

    For temperature > 0 this returns exp(logits) / (-log u) ** temperature with
    u ~ U(0, 1) drawn elementwise. Taking the log, argmax of the result equals
    argmax of logits / temperature + g with g = -log(-log u) standard Gumbel
    noise, i.e. a sample from softmax(logits / temperature) (Gumbel-max trick).
    With temperature == 0 the input is returned unchanged (greedy argmax).

    According to arXiv:2409.02908, for masked diffusion models a low-precision
    Gumbel max improves perplexity but reduces generation quality, so the noisy
    scores are computed in float64.
    '''
    if temperature != 0:
        as_f64 = logits.to(torch.float64)
        uniform = torch.rand_like(as_f64, dtype=torch.float64)
        return torch.exp(as_f64) / torch.pow(-torch.log(uniform), temperature)
    return logits


def get_num_transfer_tokens(mask_index, steps):
    '''
    Precompute how many masked tokens to unmask at each reverse-process step.

    The interval [0, 1] is uniformly discretized into `steps` intervals. Because
    LLaDA uses a linear noise schedule (Eq. (8)), the expected number of tokens
    transitioned per step is constant. So each row's masked-token count is split
    as evenly as possible: every step gets count // steps tokens, and the first
    count % steps steps get one extra.

    Args:
        mask_index: [batch, length] tensor, truthy where a position is masked.
        steps: number of steps to spread the masked tokens over.

    Returns:
        int64 tensor of shape [batch, steps] on mask_index's device; each row
        sums to that row's masked-token count.
    '''
    masked_per_row = mask_index.sum(dim=1, keepdim=True)  # [batch, 1]
    per_step = masked_per_row // steps
    leftover = masked_per_row % steps

    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)  # [1, steps]
    # Steps with index < leftover carry the one extra token for that row.
    extra = (step_ids < leftover).to(torch.int64)
    return per_step + extra



@torch.no_grad()
def vanilla(input_ids, attention_mask, model, gen_length, block_length, steps, temperature, cfg_scale,
            mask_id, remasking, **kwargs):
    '''
    Generate `gen_length` tokens after the prompt by iterative unmasking, block by block (no KV cache).

    The completion region starts filled with `mask_id`. It is split into
    gen_length // block_length blocks decoded left to right, and each block
    receives steps // num_blocks denoising steps. At every step the model runs
    on the full sequence. A candidate token is proposed for every position
    (argmax of Gumbel-perturbed logits, see add_gumbel_noise) and scored. Then,
    for each row, the get_num_transfer_tokens-scheduled number of still-masked
    positions with the highest scores are committed. Scores to the right of the
    current block are forced to -inf.

    Args:
        input_ids: [batch_size, prompt_length] prompt token ids.
        attention_mask: accepted for interface compatibility; not used here.
        model: callable returning an object with `.logits`; must expose `.device`.
        gen_length: number of tokens to generate; must be a multiple of block_length.
        block_length: length of each block.
        steps: total denoising steps; must be a multiple of the number of blocks.
        temperature: Gumbel temperature (0 means plain argmax).
        cfg_scale: if > 0, classifier-free guidance. A second pass with all
            non-mask prompt tokens replaced by mask_id gives unconditional logits
            `un`, combined as un + (cfg_scale + 1) * (cond - un).
        mask_id: id of the mask token.
        remasking: 'low_confidence' scores candidates by their softmax
            probability; 'random' by uniform noise; anything else raises
            NotImplementedError.

    Returns:
        (x, nfe): x is the [batch_size, prompt_length + gen_length] long tensor
        of prompt plus generated ids; nfe counts denoising steps (one per step,
        even when CFG doubles the batch).
    '''
    batch_size, prompt_length = input_ids.shape
    total_length = prompt_length + gen_length
    x = torch.full((batch_size, total_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :prompt_length] = input_ids
    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    prompt_index = x != mask_id

    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks
    nfe = 0

    for block_idx in range(num_blocks):
        block_lo = prompt_length + block_idx * block_length
        block_hi = block_lo + block_length
        schedule = get_num_transfer_tokens(x[:, block_lo:block_hi] == mask_id, steps_per_block)

        for step in range(steps_per_block):
            nfe += 1
            still_masked = x == mask_id

            if cfg_scale > 0.:
                unconditional = x.clone()
                unconditional[prompt_index] = mask_id
                joint_logits = model(torch.cat([x, unconditional], dim=0)).logits
                cond_logits, uncond_logits = torch.chunk(joint_logits, 2, dim=0)
                logits = uncond_logits + (cfg_scale + 1) * (cond_logits - uncond_logits)
            else:
                logits = model(x).logits

            # Gumbel noise is drawn before any remasking randomness (RNG order matters).
            candidates = torch.argmax(add_gumbel_noise(logits, temperature=temperature), dim=-1)  # b, l

            if remasking == 'low_confidence':
                probs = F.softmax(logits, dim=-1)
                scores = probs.gather(-1, candidates.unsqueeze(-1)).squeeze(-1)  # b, l
            elif remasking == 'random':
                scores = torch.rand(candidates.shape[:2], device=candidates.device)
            else:
                raise NotImplementedError(remasking)

            # Keep positions to the right of the current block out of contention.
            scores[:, block_hi:] = -np.inf

            candidates = torch.where(still_masked, candidates, x)
            scores = torch.where(still_masked, scores, -np.inf)

            commit = torch.zeros_like(candidates, dtype=torch.bool)
            for row in range(scores.shape[0]):
                chosen = torch.topk(scores[row], k=schedule[row, step]).indices
                commit[row, chosen] = True
            x[commit] = candidates[commit]

    return x, nfe

@torch.no_grad()
def vanilla_with_cache(input_ids, attention_mask, model, gen_length, block_length, temperature, cfg_scale,
                       mask_id, remasking, refresh_interval, **kwargs):
    '''
    Block-wise masked-diffusion decoding that reuses a KV cache between full refreshes.

    The completion region starts filled with `mask_id` and is decoded block by
    block, left to right. Exactly one token per row is unmasked per model
    forward pass, so the number of forward passes equals `gen_length` (checked
    by the final assertion).

    Whenever `nfe` is a multiple of `refresh_interval`, the model runs on the
    full sequence with `use_cache=True` and `past_key_values` is replaced.
    Otherwise only the current block is fed, together with the cached
    `past_key_values` and a boolean `replace_position` mask marking the block's
    columns. NOTE(review): the `past_key_values` / `replace_position` keywords
    belong to the project's model forward(); confirm against it.

    Args:
        input_ids: [batch_size, prompt_length] prompt token ids.
        attention_mask: accepted for interface compatibility; not used here.
        model: callable returning an object with `.logits` (and
            `.past_key_values` when `use_cache=True`); must expose `.device`.
        gen_length: number of tokens to generate; must be a multiple of block_length.
        block_length: length of each block.
        temperature: Gumbel temperature passed to add_gumbel_noise (0 = argmax).
        cfg_scale: accepted for interface compatibility; not used here.
        mask_id: id of the mask token.
        remasking: accepted for interface compatibility; this path always
            commits the masked position whose sampled token has the highest
            softmax probability.
        refresh_interval: recompute the full-sequence cache when nfe % refresh_interval == 0.

    Returns:
        (x, nfe): the [batch_size, prompt_length + gen_length] long tensor of
        prompt plus generated ids, and the number of model forward passes.
    '''
    batch_size, prompt_length = input_ids.shape
    x = torch.full(
        (batch_size, prompt_length + gen_length),
        mask_id,
        dtype=torch.long,
        device=model.device,
    )
    x[:, :prompt_length] = input_ids
    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length
    nfe = 0
    past_key_values = None

    for num_block in range(num_blocks):
        block_start_index = prompt_length + num_block * block_length
        block_end_index = block_start_index + block_length

        if nfe % refresh_interval == 0:
            # Full pass at block start: rebuild the cache and commit the block's first token.
            output = model(x, use_cache=True)
            past_key_values = output.past_key_values
            logits = output.logits[:, block_start_index:block_end_index]
            mask = x[:, block_start_index:block_end_index] == mask_id
            _unmask_most_confident(x, logits, mask, block_start_index, temperature)
            nfe += 1

        # Marks the current block's columns for the cached (block-only) forward pass.
        replace_position = torch.zeros_like(x, dtype=torch.bool)
        replace_position[:, block_start_index:block_end_index] = 1

        while True:
            assert past_key_values is not None
            current_block = x[:, block_start_index:block_end_index]
            mask = current_block == mask_id  # [batch_size, block_length]
            if not mask.any():
                break
            if nfe % refresh_interval == 0:
                output = model(x, use_cache=True)
                past_key_values = output.past_key_values
                logits = output.logits[:, block_start_index:block_end_index]
            else:
                logits = model(current_block,
                               past_key_values=past_key_values,
                               use_cache=True,
                               replace_position=replace_position).logits
            nfe += 1
            # logits shape [batch_size, block_length, vocab_size]
            _unmask_most_confident(x, logits, mask, block_start_index, temperature)

    assert nfe == gen_length
    return x, nfe


def _unmask_most_confident(x, logits, mask, block_start_index, temperature):
    '''
    Commit, in place, at most one token per row into the block of `x` starting at `block_start_index`.

    `logits` covers only the block ([batch_size, block_length, vocab_size]).
    `mask` ([batch_size, block_length]) marks block positions still equal to
    the mask token. A token is sampled at every block position (argmax of
    Gumbel-perturbed logits) and scored by its softmax probability. For each
    row, the highest-scoring masked position is overwritten. Rows with no
    masked position are left untouched.
    '''
    logits_with_noise = add_gumbel_noise(logits, temperature)
    prob = F.softmax(logits, dim=-1)
    predict_token_ids = torch.argmax(logits_with_noise, dim=-1)  # [batch_size, block_length]
    confidence = torch.gather(
        prob, dim=-1, index=predict_token_ids.unsqueeze(-1)).squeeze(-1)  # [batch_size, block_length]
    confidence = torch.where(mask, confidence, -torch.inf)
    relative_index = torch.topk(confidence, k=1, dim=-1).indices.squeeze(-1)  # [batch_size]

    # Pair each row with its OWN chosen column. `x[:, absolute_index]` would
    # broadcast to [batch_size, batch_size] and write every row's choice into
    # every row when batch_size > 1.
    rows = torch.arange(x.size(0), device=x.device)
    # A row with nothing left to unmask would otherwise have topk pick an
    # already-decoded (-inf scored) position and overwrite it.
    active = mask.any(dim=-1)
    rows, relative_index = rows[active], relative_index[active]
    x[rows, relative_index + block_start_index] = predict_token_ids[rows, relative_index]