import torch
import numpy as np
import math
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from transformers.utils import ModelOutput
import transformers
import time
from dataclasses import dataclass
import logging
from datetime import datetime
from model.modeling_llada import LLaDAModelLM

# from utils.generation_utils import get_num_transfer_tokens, top_p_logits, top_k_logits, sample_tokens
from utils.alpha import BaseAlphaScheduler, LinearAlphaScheduler

def add_gumbel_noise(logits, temperature):
    '''
    Perturb logits so that an argmax over the result is a Gumbel-max sample.

    The returned tensor is ``exp(logits) / (-log(u)) ** temperature`` with
    ``u ~ U(0, 1)``. Taking ``argmax`` over the last dimension is equivalent to
    ``argmax(logits / temperature + Gumbel noise)``, i.e. a draw from the
    categorical distribution ``softmax(logits / temperature)``.

    The noise is drawn and the arithmetic is done in float32. Drawing the
    uniform noise in bfloat16 leaves only a few hundred distinct noise values
    (and occasionally an exact 0), which heavily quantises the samples over a
    large vocabulary.

    Args:
        logits: Floating-point tensor of unnormalised scores.
        temperature: Sampling temperature; 0 disables the noise.

    Returns:
        `logits` itself (same object, untouched) when `temperature == 0`,
        otherwise a new float32 tensor with the same shape as `logits`.
    '''
    if temperature == 0.:
        return logits  # Skip noise when temperature is 0

    scores = logits.to(torch.float32)
    # Keep u strictly positive so that -log(u) stays finite.
    smallest = torch.finfo(torch.float32).tiny
    noise = torch.rand_like(scores).clamp_min(smallest)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return scores.exp() / gumbel_noise


@dataclass
class DreamModelOutput(ModelOutput):
    # Output container with two optional fields, both defaulting to None.
    # It is not referenced by the other definitions visible in this part of the file.
    # NOTE(review): the name suggests it mirrors the Dream model's generation output — confirm.

    # Token-id tensor (annotated as LongTensor).
    sequences: torch.LongTensor | None = None
    # Tuple of float tensors.
    # NOTE(review): presumably intermediate per-step states — confirm against whatever populates it.
    history: tuple[torch.FloatTensor] | None = None
    
def get_num_transfer_tokens(mask_index, steps):
    '''
    Split each row's masked-token count as evenly as possible over `steps` steps.

    Every step receives `count // steps` tokens and the first `count % steps`
    steps receive one extra, so each row of the result sums to the number of
    True entries in the matching row of `mask_index`.

    Args:
        mask_index: Boolean tensor of shape (batch, length); True marks a masked position.
        steps: Number of steps to spread the masked positions over.

    Returns:
        int64 tensor of shape (batch, steps) on the same device as `mask_index`.
    '''
    total_masked = mask_index.sum(dim=1, keepdim=True)

    per_step = total_masked // steps
    leftover = total_masked % steps

    # Step s gets the extra token exactly when s < leftover for that row.
    step_ids = torch.arange(steps, device=mask_index.device)
    bonus = (step_ids.unsqueeze(0) < leftover).to(torch.int64)

    return per_step + bonus
    
 
def get_num_transfer_tokens_remasking(mask_index, steps, parallel_tokens=1):
    '''
    Cumulative keep-budget per step for the remasking strategy.

    The masked-token count of each row is split as evenly as possible over
    `steps` steps (the first `count % steps` steps get one extra token) and the
    running total is returned, so entry [b, i] is how many tokens row b should
    hold after step i. The last column equals the row's masked-token count.

    Args:
        mask_index: Boolean tensor indicating which positions are masked, shape (batch, length).
        steps: Number of steps in the generation process.
        parallel_tokens: Accepted for signature compatibility; not read here.

    Returns:
        int64 tensor of shape (batch, steps) on the same device as `mask_index`.
    '''
    total_masked = mask_index.sum(dim=1, keepdim=True)
    per_step = total_masked // steps
    leftover = total_masked % steps

    # Step s gets the extra token exactly when s < leftover for that row.
    step_ids = torch.arange(steps, device=mask_index.device)
    increments = per_step + (step_ids.unsqueeze(0) < leftover).to(torch.int64)

    return torch.cumsum(increments, dim=1)


@torch.no_grad()
def generate(model, prompt, attention_mask=None, parallel_tokens=2, gen_length=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=126336, logits_eos_inf=False, confidence_eos_eot_inf=False,
             tokenizer=None, log_to_file=False, do_likelihood_analysis=False):
    '''
    Fill `gen_length` mask tokens appended to `prompt`, one block at a time.

    Each block is decoded in `(gen_length // parallel_tokens) // num_blocks` steps. At every step the
    model is run on the whole sequence, a token is predicted for every masked position, and the
    highest-confidence predictions inside the current block are committed.

    Args:
        model: Mask predictor; called as `model(x, attention_mask=...)`, must expose `.logits` on its output and `.device`.
        prompt: A tensor of shape (batch, L).
        attention_mask: Optional mask for the prompt; it is extended with ones over the generated span.
        parallel_tokens: Average number of tokens committed per step; the total step count is `gen_length // parallel_tokens`.
        gen_length: Generated answer length.
        block_length: Block length, less than or equal to gen_length. If less than gen_length, semi-autoregressive remasking.
        temperature: Categorical distribution sampling temperature.
        cfg_scale: Unsupervised classifier-free guidance scale.
        remasking: Remasking strategy. 'low_confidence' or 'random'.
        mask_id: The token id of [MASK].
        logits_eos_inf: Whether to set the logits of EOS token to -inf.
        confidence_eos_eot_inf: Whether to set the confidence of EOS and EoT token to -inf.
        tokenizer: Not used by this function.
        log_to_file: Not used by this function.
        do_likelihood_analysis: if True, returns (x, step_map) where step_map holds per-block arrays.

    Returns:
        `x`, a long tensor of shape (batch, L + gen_length). When `do_likelihood_analysis` is True,
        the tuple `(x, step_map)` where `step_map[block]` is a dict with flat arrays for batch
        index 0 only: 'conf' (confidence of each committed token), 'tok_list' (committed token ids)
        and 'idx_list' (their positions in `x`), in commit order.

    Raises:
        ValueError: if `parallel_tokens` is not positive or leaves a block with zero steps.
        AssertionError: if the block/step divisibility requirements are not met.
        NotImplementedError: for an unknown `remasking` strategy.
    '''
    if parallel_tokens < 1:
        raise ValueError(f"parallel_tokens must be >= 1, got {parallel_tokens}")

    # Token ids hard-coded by the original implementation.
    eos_id = 126081
    eot_id = 126348

    steps = gen_length // parallel_tokens
    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    if attention_mask is not None:
        attention_mask = torch.cat([attention_mask, torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, device=model.device)], dim=-1)

    prompt_index = (x != mask_id)
    assert gen_length % block_length == 0, "gen_length must be divisible by block_length"

    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0, "steps must be divisible by num_blocks"

    steps = steps // num_blocks
    if steps < 1:
        # Without this the step scheduler fails later with an opaque division-by-zero error.
        raise ValueError(
            f"parallel_tokens={parallel_tokens} leaves no decoding steps per block "
            f"(gen_length={gen_length}, block_length={block_length})"
        )

    step_map = {}
    for num_block in range(num_blocks):
        block_start = prompt.shape[1] + num_block * block_length
        block_end = prompt.shape[1] + (num_block + 1) * block_length
        block_mask_index = (x[:, block_start:block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)

        # Accumulators for this block (will become flat 1-D arrays for batch index 0)
        conf_acc = []
        tok_acc = []
        idx_acc = []

        for i in range(steps):
            mask_index = (x == mask_id)

            if cfg_scale > 0.:
                # Unconditional branch: same sequence with the prompt masked out.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                if attention_mask is not None:
                    attention_mask_ = torch.cat([attention_mask, attention_mask], dim=0)
                else:
                    attention_mask_ = None
                logits = model(x_, attention_mask=attention_mask_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x, attention_mask=attention_mask).logits

            if logits_eos_inf:
                logits[:, :, eos_id] = -torch.inf

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

            if confidence_eos_eot_inf:
                # x0 is already chosen; only the confidence (softmax of `logits`) must change.
                # Both ids are written into `logits`: writing EOS into `logits_with_noise`
                # only reached the confidence when temperature == 0 (where the two tensors alias).
                logits[:, :, eos_id] = -torch.inf
                logits[:, :, eot_id] = -torch.inf

            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Prevent selecting tokens beyond the current block
            x0_p[:, block_end:] = -np.inf

            # only propose changes at masked positions
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                k = int(num_transfer_tokens[j, i])
                if k <= 0:
                    continue
                vals, select_index = torch.topk(confidence[j], k=k)
                transfer_index[j, select_index] = True

                if do_likelihood_analysis and j == 0:
                    # Record only batch index 0 (to match downstream get_likelihood usage)
                    conf_acc.append(vals.detach().cpu().float().numpy())
                    tok_acc.append(x0[j, select_index].detach().cpu().numpy())
                    idx_acc.append(select_index.detach().cpu().numpy())

            # apply transferred tokens to x
            x[transfer_index] = x0[transfer_index]

        # After all steps for this block, collapse accumulators into flat 1-D arrays
        if do_likelihood_analysis:
            if len(conf_acc) == 0:
                step_map[num_block] = {
                    'conf': np.array([], dtype=np.float32),
                    'tok_list': np.array([], dtype=np.int64),
                    'idx_list': np.array([], dtype=np.int64)
                }
            else:
                # Each element of conf_acc/tok_acc/idx_acc is a 1-D array of length k for that step.
                step_map[num_block] = {
                    'conf': np.concatenate([np.asarray(a).ravel() for a in conf_acc], axis=0).astype(np.float32),
                    'tok_list': np.concatenate([np.asarray(a).ravel() for a in tok_acc], axis=0).astype(np.int64),
                    'idx_list': np.concatenate([np.asarray(a).ravel() for a in idx_acc], axis=0).astype(np.int64)
                }

    # Keep the plain-tensor return for callers that do not ask for the analysis.
    if do_likelihood_analysis:
        return x, step_map

    return x

@torch.no_grad()
def generate_with_prefix_cache(model, prompt, attention_mask=None, parallel_tokens=2, gen_length=128, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence', mask_id=126336, logits_eos_inf=False, confidence_eos_eot_inf=False,threshold=None, factor = None):
    '''
    Block-wise generation that reuses the key/value cache of everything before the current block.

    For each block the model is run once on the full sequence; the returned cache is truncated to
    the positions before the block and reused for the remaining refinement steps, which only feed
    the suffix starting at the block. Refinement continues until the block contains no mask token.

    Args:
        model: Mask predictor; called with `use_cache=True` and, for refinement steps, `past_key_values`.
        prompt: A tensor of shape (batch, L).
        attention_mask: Accepted for signature parity; not forwarded to the model.
        parallel_tokens: Average number of tokens committed per step; the total step count is `gen_length // parallel_tokens`.
        gen_length: Generated answer length.
        block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
        temperature: Categorical distribution sampling temperature.
        cfg_scale: Not used by this function.
        remasking: Remasking strategy, forwarded to the transfer-index helper.
        mask_id: The token id of [MASK] is 126336.
        logits_eos_inf: Not used by this function.
        confidence_eos_eot_inf: Not used by this function.
        threshold: If given (and `factor` is None), passed to `get_transfer_index` instead of a per-step quota.
        factor: If given, `get_transfer_index_dynamic` is used with this factor.

    Returns:
        Long tensor of shape (batch, L + gen_length).

    NOTE(review): the refinement loop ends only when the block is fully unmasked, so it relies on
    the transfer-index helpers (defined outside this view) committing at least one token per call.
    '''

    steps = gen_length // parallel_tokens

    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        current_block_start = prompt.shape[1] + num_block * block_length
        current_block_end = current_block_start + block_length

        block_mask_index = (x[:, current_block_start:current_block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)

        # Full forward pass: produces the first step's logits and the cache for this block.
        output = model(x, use_cache=True)

        mask_index = (x == mask_id)
        mask_index[:, current_block_end:] = 0
        if factor is None:
            x0, transfer_index = get_transfer_index(output.logits, temperature, remasking, mask_index, x, num_transfer_tokens[:, 0] if threshold is None else None, threshold)
        else:
            x0, transfer_index = get_transfer_index_dynamic(output.logits, temperature, remasking, mask_index, x, None, factor)
        x[transfer_index] = x0[transfer_index]

        # Keep only the cache entries of the positions before the current block.
        past_key_values = [
            tuple(entry[:, :, :current_block_start] for entry in layer)
            for layer in output.past_key_values
        ]

        i = 1
        while (x[:, current_block_start:current_block_end] == mask_id).any():
            mask_index = (x[:, current_block_start:] == mask_id)
            mask_index[:, block_length:] = 0

            logits = model(x[:, current_block_start:], past_key_values=past_key_values, use_cache=True).logits

            # The helper returns both the proposed tokens and the positions to commit, so no
            # separate noise/argmax pass over the vocabulary is needed here.
            if factor is None:
                x0, transfer_index = get_transfer_index(logits, temperature, remasking, mask_index,
                                                x[:, current_block_start:], num_transfer_tokens[:, i] if threshold is None else None, threshold)
            else:
                x0, transfer_index = get_transfer_index_dynamic(logits, temperature, remasking, mask_index,
                                                x[:, current_block_start:], None, factor)
            x[:, current_block_start:][transfer_index] = x0[transfer_index]

            i += 1

    return x


@torch.no_grad()
@torch.compile(mode="max-autotune", fullgraph=True)
def generate_with_dual_cache(model, prompt, attention_mask=None, parallel_tokens=2, gen_length=128, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence', mask_id=126336, logits_eos_inf=False, confidence_eos_eot_inf=False,threshold=None, factor = None):
    '''
    Block-wise generation with a fixed number of steps per block and a reused KV cache.

    For each block the model is run once on the full sequence to obtain logits and a cache, then
    `steps_per_block - 1` further passes feed only the current block together with that cache and a
    boolean `replace_position` mask marking the block's columns.

    Args:
        model: Mask predictor; called with `use_cache=True` and, in refinement steps, with
            `past_key_values` and `replace_position`.
        prompt: Tensor of shape (batch, Lp).
        attention_mask: Not read by this function.
        parallel_tokens: The total step count is `gen_length // parallel_tokens`.
        gen_length: Generated answer length; must be divisible by `block_length`.
        block_length: Block length.
        temperature: Forwarded to the transfer-index helpers.
        cfg_scale: Not read by this function.
        remasking: Forwarded to the transfer-index helpers.
        mask_id: Token id of the mask token.
        logits_eos_inf: Not read by this function.
        confidence_eos_eot_inf: Not read by this function.
        threshold: If given (and `factor` is None), passed to `get_transfer_index` in place of a per-step quota.
        factor: If given, `get_transfer_index_dynamic` is used with this factor.

    Returns:
        Long tensor of shape (batch, Lp + gen_length).

    NOTE(review): `get_transfer_index` / `get_transfer_index_dynamic` are defined outside this view,
    and how the model uses `replace_position` is not visible here.
    NOTE(review): `get_num_transfer_tokens`, as defined earlier in this file, uses a Python loop with
    tensor-valued slice bounds — confirm that it traces under `fullgraph=True`.
    '''
    steps = gen_length // parallel_tokens
    B = prompt.shape[0]
    Lp = int(prompt.shape[1])  # Python int, not Tensor
    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    # x: (B, Lp + gen_length)
    x = torch.full((B, Lp + gen_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :Lp] = prompt

    # Counter of model evaluations; it is updated below but never returned or logged.
    nfe = 0

    for nb in range(num_blocks):
        # [s, e) is the column range of the current block inside x.
        s = Lp + nb * block_length
        e = s + block_length

        # Masks/indices for the current block
        block_mask_index = (x[:, s:e] == mask_id)  # (B, block_length)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)  # (B, steps_per_block)

        # 1) Run the model on the full sequence once per block to obtain logits and the KV cache
        out_full = model(x, use_cache=True)
        past_key_values = out_full.past_key_values
        nfe += 1

        # Build a replace_position tensor indicating the block range (static slice)
        replace_position = torch.zeros_like(x, dtype=torch.bool)
        replace_position[:, s:e] = True  # boolean mask (not a dynamic slice bound)

        # Step 0: do an initial transfer on the full logits
        global_mask_index = (x == mask_id)
        # Do not touch beyond current block in this phase
        global_mask_index[:, e:] = False

        if factor is None:
            # A per-step quota is only passed when no threshold is given.
            quota0 = None if threshold is not None else num_transfer_tokens[:, 0]  # (B,)
            x0, transfer_index = get_transfer_index(
                out_full.logits, temperature, remasking, global_mask_index, x, quota0, threshold
            )
        else:
            x0, transfer_index = get_transfer_index_dynamic(
                out_full.logits, temperature, remasking, global_mask_index, x, None, factor
            )

        # Rebind x to a new tensor built with torch.where (no masked in-place assignment)
        x = torch.where(transfer_index, x0, x)
        # NOTE(review): this is the second increment for the single forward pass above — confirm intent.
        nfe += 1  # counted initial + this update

        # 2) Semi-autoregressive refinement, fixed number of steps (graph-friendly)
        #    Each iteration runs on the current block with KV-cache and replace_position
        for i in range(1, steps_per_block):
            # Evaluate logits only for current block with cache
            logits_blk = model(
                x[:, s:e], past_key_values=past_key_values, use_cache=True, replace_position=replace_position
            ).logits  # shape expected by get_transfer_index*

            # Mask and quota for this step (all tensor ops)
            mask_blk = (x[:, s:e] == mask_id)  # (B, block_length)

            if factor is None:
                quota_i = None if threshold is not None else num_transfer_tokens[:, i]  # (B,)
                x0_blk, transfer_idx_blk = get_transfer_index(
                    logits_blk, temperature, remasking, mask_blk, x[:, s:e], quota_i, threshold
                )
            else:
                x0_blk, transfer_idx_blk = get_transfer_index_dynamic(
                    logits_blk, temperature, remasking, mask_blk, x[:, s:e], None, factor
                )

            # Merge back into x[:, s:e] using torch.where (no masked slice assignment)
            blk_old = x[:, s:e]
            blk_new = torch.where(transfer_idx_blk, x0_blk, blk_old)
            x = torch.cat([x[:, :s], blk_new, x[:, e:]], dim=1)  # static concatenation

            nfe += 1

    return x



@torch.no_grad()
def generate_with_remasking(model, prompt, attention_mask=None, parallel_tokens=1, gen_length=128, block_length=128, temperature=0., remasking = None,
             cfg_scale=0., mask_id=126336, logits_eos_inf=False, confidence_eos_eot_inf=False, tokenizer=None, log_to_file=False):
    '''
    Block-wise generation in which already-committed tokens can be remasked.

    At step i of a block, the `num_transfer_tokens[:, i]` positions with the highest confidence are
    kept, where the candidates are both the tokens kept at the previous step (with the confidence
    they were kept at) and the current predictions for masked positions. Previously kept positions
    that fall out of this top-k are set back to the mask token.

    Args:
        model: Mask predictor.
        prompt: A tensor of shape (batch, L).
        attention_mask: Optional mask for the prompt; it is extended with ones over the generated span.
        parallel_tokens: The total step count is `gen_length // parallel_tokens`.
        gen_length: Generated answer length.
        block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
        temperature: Categorical distribution sampling temperature.
        remasking: Not used by this function.
        cfg_scale: Unsupervised classifier-free guidance scale.
        mask_id: The token id of [MASK] is 126336.
        logits_eos_inf: Whether to set the logits of EOS token to -inf. See Appendix B.4 of LLaDA for details
        confidence_eos_eot_inf: Whether to set the confidence of EOS and EoT token to -inf. See Appendix B.4 of LLaDA for details
        tokenizer: Optional tokenizer for logging decoded text at each step.
        log_to_file: If True, logs generation details to a timestamped file.

    Returns:
        Long tensor of shape (batch, L + gen_length).

    Raises:
        ValueError: if `parallel_tokens` is not positive or leaves a block with zero steps.
        AssertionError: if the block/step divisibility requirements are not met.
    '''
    if parallel_tokens < 1:
        raise ValueError(f"parallel_tokens must be >= 1, got {parallel_tokens}")

    # Token ids hard-coded by the original implementation.
    eos_id = 126081
    eot_id = 126348

    steps = gen_length // parallel_tokens

    # Setup logging to file if enabled
    log_file = None
    log_filename = None
    if log_to_file:
        log_filename = f"generation_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
        # Decoded model text may contain any Unicode character, so do not rely on the locale encoding.
        log_file = open(log_filename, 'w', encoding='utf-8')

    try:
        if log_file:
            log_file.write(f"Generation started at {datetime.now()}\n")
            log_file.write(f"Parameters: steps={steps}, gen_length={gen_length}, block_length={block_length}, temperature={temperature}, cfg_scale={cfg_scale}\n")
            log_file.write("="*80 + "\n\n")

        x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
        x[:, :prompt.shape[1]] = prompt.clone()

        if attention_mask is not None:
            attention_mask = torch.cat([attention_mask, torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, device=model.device)], dim=-1)

        prompt_index = (x != mask_id)

        assert gen_length % block_length == 0
        num_blocks = gen_length // block_length

        assert steps % num_blocks == 0
        steps = steps // num_blocks
        if steps < 1:
            # Without this the step scheduler fails later with an opaque division-by-zero error.
            raise ValueError(
                f"parallel_tokens={parallel_tokens} leaves no decoding steps per block "
                f"(gen_length={gen_length}, block_length={block_length})"
            )

        for num_block in range(num_blocks):
            block_start = prompt.shape[1] + num_block * block_length
            block_end = prompt.shape[1] + (num_block + 1) * block_length
            block_mask_index = (x[:, block_start:block_end] == mask_id)
            num_transfer_tokens = get_num_transfer_tokens_remasking(block_mask_index, steps, parallel_tokens)
            # Confidence of the tokens kept at the previous step (-inf everywhere else).
            # Reset for every block.
            cumulative_confidence = None

            for i in range(steps):
                if log_file:
                    log_file.write(f"\n--- Step {i}, Block {num_block} ---\n")
                    log_file.flush()
                mask_index = (x == mask_id)
                if cfg_scale > 0.:
                    # Unconditional branch: same sequence with the prompt masked out.
                    un_x = x.clone()
                    un_x[prompt_index] = mask_id
                    x_ = torch.cat([x, un_x], dim=0)
                    if attention_mask is not None:
                        attention_mask_ = torch.cat([attention_mask, attention_mask], dim=0)
                    else:
                        # Must be bound even without a mask, otherwise the call below raises.
                        attention_mask_ = None
                    logits = model(x_, attention_mask=attention_mask_).logits
                    logits, un_logits = torch.chunk(logits, 2, dim=0)
                    logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
                else:
                    logits = model(x, attention_mask=attention_mask).logits

                if logits_eos_inf:
                    logits[:, :, eos_id] = -torch.inf

                logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
                x0 = torch.argmax(logits_with_noise, dim=-1) # b, l

                if confidence_eos_eot_inf:
                    # x0 is already chosen; only the confidence (softmax of `logits`) must change.
                    # Both ids are written into `logits`: writing EOS into `logits_with_noise`
                    # only reached the confidence when temperature == 0 (where the two tensors alias).
                    logits[:, :, eos_id] = -torch.inf
                    logits[:, :, eot_id] = -torch.inf

                # Calculate confidence scores based on softmax probabilities
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l

                # Positions beyond the current block can never be selected
                x0_p[:, block_end:] = -np.inf

                x0 = torch.where(mask_index, x0, x)
                confidence = torch.where(mask_index, x0_p, -np.inf)

                # Candidates: tokens kept so far (old confidence) and current predictions on masked positions
                if cumulative_confidence is not None:
                    combined_confidence = torch.maximum(cumulative_confidence, confidence)
                else:
                    # First step: use current confidence only
                    combined_confidence = confidence

                # Select top k tokens based on combined confidence
                transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
                for j in range(combined_confidence.shape[0]):
                    k = int(num_transfer_tokens[j, i])
                    _, select_index = torch.topk(combined_confidence[j], k=k)
                    transfer_index[j, select_index] = True

                # Positions that were kept at the previous step but are not selected now
                if cumulative_confidence is not None:
                    prev_selected = cumulative_confidence > -np.inf
                else:
                    prev_selected = torch.zeros_like(transfer_index, dtype=torch.bool)
                dropped = prev_selected & ~transfer_index

                # Remask dropped positions
                x[dropped] = mask_id

                # Update with new selected tokens
                x[transfer_index] = x0[transfer_index]

                # Update cumulative confidence: keep confidence for selected positions, set others to -inf
                cumulative_confidence = torch.where(transfer_index, combined_confidence, torch.full_like(combined_confidence, -np.inf))

                # Log current generation state
                if log_file and tokenizer is not None:
                    log_file.write(f"\nCurrent generation state after step {i}:\n")
                    for j in range(x.shape[0]):
                        # Decode the generated portion (excluding prompt)
                        generated_tokens = x[j, prompt.shape[1]:]
                        decoded = tokenizer.decode(generated_tokens, skip_special_tokens=False)
                        log_file.write(f"Sample {j}: {decoded}\n")

                        # Show which positions are still masked
                        masked_positions = torch.where(generated_tokens == mask_id)[0].tolist()
                        log_file.write(f"  Masked positions: {masked_positions}\n")
                        log_file.write(f"  Num masked: {len(masked_positions)}/{gen_length}\n")
                    log_file.write("\n")
                    log_file.flush()

        if log_file:
            log_file.write(f"\n\nGeneration completed at {datetime.now()}\n")
            log_file.write(f"Log saved to: {log_filename}\n")
    finally:
        # Close the log even when generation fails part-way.
        if log_file:
            log_file.close()

    if log_filename:
        print(f"\nGeneration log saved to: {log_filename}")

    return x

@torch.no_grad()
def infilling(
    model: transformers.PreTrainedModel,
    tokenizer: transformers.PreTrainedTokenizer,
    inputs_with_blanks: list[torch.Tensor],
    scheduler: BaseAlphaScheduler = LinearAlphaScheduler(),
    parallel_tokens: int = 4,
    max_new_tokens: int = 256,
    block_length: int | None = None,
    temperature: float = 0.0,
    cfg_scale: float = 0.0,
    cfg_keep_tokens: list = None,
    remasking: str = "random",
    return_dict_in_generate: bool = False,
    stochastic_transfer: bool = False,
) -> torch.Tensor | dict:
    """
    Fill in-place the <|mdm_mask|> tokens contained in `inputs_with_blanks`.
    The whole (padded) sequence is split into block windows of length
    `block_length`; within each window we progressively "unmask" positions
    according to the chosen remasking strategy.

    Notes:
    - Right padding uses EOS.
    - CFG masks out *originally known* (non-mask, non-EOS) tokens in the
      unconditional branch.
    - Only masked positions are ever updated; non-mask tokens are left intact.
    - The tokenizer is modified: `mask_token`, `mask_token_id` and
      `chat_template` are (re)assigned on every call.
    - The step budget comes from `get_num_transfer_tokens(mask_index, steps)`
      as defined in this file, which implements only an even, deterministic
      split. `scheduler` and `stochastic_transfer` are therefore validated but
      not forwarded.
      NOTE(review): if a scheduler-aware helper is restored (see the
      commented-out `utils.generation_utils` import), forward them again.

    Args:
        model: Mask predictor; called as `model(x)` and must expose `.logits` on its output and `.device`.
        tokenizer: Tokenizer providing the mask and EOS token ids.
        inputs_with_blanks: One 1-D token-id tensor per sample, containing mask tokens to fill.
        scheduler: Must be a `LinearAlphaScheduler` (see Notes).
        parallel_tokens: The total step budget is `max_new_tokens // parallel_tokens`.
        max_new_tokens: Used only to derive the step budget.
        block_length: Window length; defaults to the padded sequence length.
        temperature: Gumbel-max sampling temperature (0 means greedy).
        cfg_scale: Classifier-free guidance scale; 0 disables guidance.
        cfg_keep_tokens: Token ids that stay visible in the unconditional branch.
        remasking: "low_confidence" or "random".
        return_dict_in_generate: If True, return a dict instead of the tensor.
        stochastic_transfer: Must be False (see Notes).

    Returns:
        The filled (batch, T) long tensor, or a dict with keys "sequences" and
        "effective_steps_per_block" when `return_dict_in_generate` is True.

    Raises:
        NotImplementedError: for an unsupported scheduler, `stochastic_transfer=True`
            or an unknown `remasking` strategy.
    """
    if stochastic_transfer:
        raise NotImplementedError(
            "stochastic_transfer is not supported by the get_num_transfer_tokens defined in this file"
        )
    if not isinstance(scheduler, LinearAlphaScheduler):
        raise NotImplementedError(
            "only LinearAlphaScheduler is supported by the get_num_transfer_tokens defined in this file"
        )

    # TODO: attention mask to avoid looking at the padding eos
    #       (short sequences in the batch).
    device = model.device
    steps = max_new_tokens // parallel_tokens


    tokenizer.mask_token = "<|mdm_mask|>"
    tokenizer.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
    tokenizer.chat_template = """
        {% set loop_messages = messages -%}
        {%- for message in loop_messages %}
        {%- if loop.index0 == 0 -%}{{ bos_token }}{%- endif -%}
        <|start_header_id|>{{ message['role'] }}<|end_header_id|>
        
        {{ message['content'] | trim }}<|eot_id|>
        {%- endfor -%}
        {%- if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %}
        <|start_header_id|>assistant<|end_header_id|>
        
        {% endif %}
        """.lstrip()

    mask_id = tokenizer.mask_token_id
    eos_id = tokenizer.eos_token_id

    # ----- Build canvas: right-pad with EOS to the max length in the batch -----
    B = len(inputs_with_blanks)
    seq_lens = [t.shape[0] for t in inputs_with_blanks]
    T = max(seq_lens)

    # Default to a single block spanning the whole sequence
    if block_length is None:
        block_length = T

    assert 1 <= block_length
    assert 1 <= steps

    x = torch.full((B, T), eos_id, dtype=torch.long, device=device)
    for i, t in enumerate(inputs_with_blanks):
        x[i, : seq_lens[i]] = t

    # Tokens that were *given* at the start (non-mask, non-EOS).
    # These will be masked in the unconditional forward pass for CFG.
    # Tokens from `cfg_keep_tokens` should *not* be treated as "given" for CFG
    unmasked_index = (x != mask_id) & (x != eos_id)
    if not (cfg_keep_tokens is None or len(cfg_keep_tokens) == 0):
        keep_mask = torch.isin(x, torch.as_tensor(cfg_keep_tokens, device=device))
        unmasked_index = unmasked_index & ~keep_mask

    # ----- Blockwise schedule over the *entire* (padded) sequence -----
    num_blocks = math.ceil(T / block_length)
    steps_per_block = math.ceil(steps / num_blocks)
    effective_steps_per_block: list[int] = []

    for b in range(num_blocks):
        start = b * block_length
        stop = min(start + block_length, T)

        # Per-sample view of which positions in this block are masks
        block_mask_index = torch.zeros(
            (B, block_length), dtype=torch.bool, device=device
        )
        widths = []
        for j in range(B):
            # Width limited by sample's true length and sequence end
            width = max(0, min(seq_lens[j], stop) - start)
            widths.append(width)
            if width > 0:
                block_mask_index[j, :width] = x[j, start : start + width] == mask_id

        # Decide how many tokens to reveal at each step in this block
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        # The helper front-loads the budget, so steps in which no sample commits
        # anything form a suffix; skipping them saves forward passes.
        # Blocks without masks => effective_steps == 0
        effective_steps = int((num_transfer_tokens > 0).any(dim=0).sum())
        effective_steps_per_block.append(effective_steps)

        for s in range(effective_steps):
            mask_index_full = x == mask_id

            # ----- Forward pass (+ optional CFG) -----
            if cfg_scale > 0.0:
                un_x = x.clone()
                un_x[unmasked_index] = mask_id  # mask out originally known tokens
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            # Greedy with optional Gumbel-Max noise
            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # [B, T]

            # Confidence used for choosing which masks to commit this step
            if remasking == "low_confidence":
                p = F.softmax(logits, dim=-1)
                x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(
                    -1
                )  # [B, T]
            elif remasking == "random":
                x0_p = torch.rand((B, T), device=device)
            else:
                raise NotImplementedError(remasking)

            # Restrict selection to the *current* block only
            for j in range(B):
                end_j = start + widths[j]
                # Outside current block => impossible to select
                x0_p[j, :start] = -np.inf
                x0_p[j, end_j:] = -np.inf

            # Only consider currently-masked positions as candidates
            x0 = torch.where(mask_index_full, x0, x)
            confidence = torch.where(mask_index_full, x0_p, -np.inf)

            # Pick exactly num_transfer_tokens[j, s] positions per sample
            transfer_index = torch.zeros_like(x, dtype=torch.bool)
            for j in range(B):
                k = int(num_transfer_tokens[j, s].item())
                if k > 0:
                    _, select_idx = torch.topk(confidence[j], k=k)
                    transfer_index[j, select_idx] = True

            # Commit selected predictions into the canvas
            x[transfer_index] = x0[transfer_index]

    if not return_dict_in_generate:
        return x
    else:
        return {
            "effective_steps_per_block": effective_steps_per_block,
            "sequences": x,
        }
        
        
def generate_then_infill(
    model: transformers.PreTrainedModel,
    tokenizer: transformers.PreTrainedTokenizer,
    prompts: list[torch.Tensor],
    generate_parallel_tokens: int = 1,
    infill_parallel_tokens: int = 4,
    max_new_tokens: int = 256,
    max_length: int = 1024,
    block_length: int = 32,
    temperature: float = 1.0,
    cfg_scale: float = 0.0,
    cfg_keep_tokens: list = None,
    remasking: str = "low_confidence",
    return_dict_in_generate: bool = False,
    early_stop_step: int | None = None,
    ):
    
    """
    Work in progress.
    First generate with early stopping, then infill the remaining masks.

    Calls `generate(...)` and passes its 'sequences' entry, one row per prompt,
    to `infilling(...)`. Returns the decoded text (special tokens skipped) of
    the first infilled sequence only.

    NOTE(review): the `generate` defined earlier in this file does not accept
    `prompts`, `max_new_tokens`, `max_length`, `cfg_keep_tokens`,
    `return_dict_in_generate` or `early_stop_step`, and returns a tensor (or a
    tuple), not a dict — this call presumably targets a different `generate`
    API; confirm which one is intended before relying on this function.
    """
    
   
    # First, generate with early stopping
    gen_output = generate(
        model=model,
        tokenizer=tokenizer,
        prompts=prompts,
        parallel_tokens=generate_parallel_tokens,
        max_new_tokens=max_new_tokens,
        max_length=max_length,
        block_length=block_length,
        temperature=temperature,
        cfg_scale=cfg_scale,
        cfg_keep_tokens=cfg_keep_tokens,
        remasking=remasking,
        return_dict_in_generate=True,
        early_stop_step=early_stop_step
    )
    
    
    # Second, fill whatever mask tokens the first pass left behind.
    infill_out = infilling(
        model=model,
        tokenizer=tokenizer,
        inputs_with_blanks=[gen_output['sequences'][i] for i in range(len(prompts))],
        scheduler=LinearAlphaScheduler(),
        parallel_tokens=infill_parallel_tokens,
        max_new_tokens=max_new_tokens,
        block_length=block_length,
        temperature=temperature,
        cfg_scale=cfg_scale,
        cfg_keep_tokens=cfg_keep_tokens,
        remasking=remasking,
        return_dict_in_generate=return_dict_in_generate,
        stochastic_transfer=False
    )

    # NOTE(review): `infilling` returns a plain tensor unless return_dict_in_generate
    # is True, so with the default (False) this string index fails — confirm intent.
    decoded = tokenizer.decode(infill_out['sequences'][0], skip_special_tokens=True)

    return decoded


def chat():
    """Run an interactive, multi-turn chat loop on the console.

    Loads the model and tokenizer from the local ``./LLaDA-8B-Instruct``
    directory onto CUDA, then repeatedly reads a question from stdin, appends
    it to the running conversation, calls ``generate`` and prints the reply.
    The loop has no exit condition; it runs until the process is interrupted.
    """
    device = 'cuda'

    # Load model with optimized settings
    model = AutoModel.from_pretrained('./LLaDA-8B-Instruct',
                                     trust_remote_code=True,
                                     torch_dtype=torch.bfloat16,
                                     low_cpu_mem_usage=True).to(device).eval()

    tokenizer = AutoTokenizer.from_pretrained('./LLaDA-8B-Instruct', trust_remote_code=True)

    gen_length = 128
    parallel_tokens = 1
    # The banner used to reference an undefined `steps` (NameError before the
    # loop ever started). Derive it the same way message_chat's banner does.
    steps = gen_length // parallel_tokens
    print('*' * 66)
    print(f'**  Answer Length: {gen_length}  |  Sampling Steps: {steps}  **')
    print('*' * 66)

    conversation_num = 0
    prompt = None

    while True:
        user_input = input("Enter your question: ")

        m = [{"role": "user", "content": user_input}]
        user_input = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
        input_ids = tokenizer(user_input, return_tensors="pt").input_ids.to(device)

        if conversation_num == 0:
            prompt = input_ids
        else:
            # Drop the first token of each follow-up turn (presumably the BOS
            # token -- TODO confirm) before appending it to the history.
            prompt = torch.cat([prompt, input_ids[:, 1:]], dim=1)

        # Release cached, currently unused CUDA blocks before generating.
        torch.cuda.empty_cache()

        out = generate(model, tokenizer, [prompt[0]],
                      parallel_tokens=parallel_tokens,
                      max_new_tokens=gen_length,
                      block_length=32,
                      temperature=0.,
                      cfg_scale=0.,
                      remasking='low_confidence')

        # Everything past the prompt length is the newly generated reply.
        answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
        print(f"Bot's reply: {answer}")

        # Remove the <EOS> token (id 126081) so the next turn continues the
        # same conversation; boolean indexing flattens, hence the unsqueeze.
        prompt = out[out != 126081].unsqueeze(0)
        conversation_num += 1
        print('-----------------------------------------------------------------------')

def batched_chat():
    """Interactive chat loop that generates ``batch_n`` samples per question.

    Loads ``GSAI-ML/LLaDA-8B-Instruct`` onto CUDA, then for every question
    read from stdin repeats the same conversation prompt ``batch_n`` times,
    calls ``generate`` on that batch, prints every decoded sample together
    with the wall-clock generation time, and continues the conversation from
    the first sample. The loop runs until the process is interrupted.
    """
    device = 'cuda'

    # Load model with optimized settings
    model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct',
                                     trust_remote_code=True,
                                     torch_dtype=torch.bfloat16,
                                     low_cpu_mem_usage=True).to(device).eval()

    tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)

    gen_length = 256
    parallel_tokens = 1
    block_length = 256
    batch_n = 32  # number of copies of the prompt generated in one batch
    # The banner used to reference an undefined `steps` (NameError before the
    # loop ever started). Derive it the same way message_chat's banner does.
    steps = gen_length // parallel_tokens

    print('*' * 66)
    print(f'**  Answer Length: {gen_length} | Block Length: {block_length} |  Sampling Steps: {steps} | Batch Size: {batch_n} **')
    print('*' * 66)

    conversation_num = 0
    prompt = None

    while True:
        user_input = input("Enter your question: ")

        m = [{"role": "user", "content": user_input}]
        user_input = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
        input_ids = tokenizer(user_input, return_tensors="pt").input_ids.to(device)

        if conversation_num == 0:
            prompt = input_ids
        else:
            # Drop the first token of each follow-up turn (presumably the BOS
            # token -- TODO confirm) before appending it to the history.
            prompt = torch.cat([prompt, input_ids[:, 1:]], dim=1)

        # Release cached, currently unused CUDA blocks before generating.
        torch.cuda.empty_cache()

        # Repeat the SAME prompt batch_n times for batched inference.
        batched_prompts = [prompt[0] for _ in range(batch_n)]
        start_time = time.time()
        out = generate(model, tokenizer, batched_prompts,
                       parallel_tokens=parallel_tokens,
                       max_new_tokens=gen_length,
                       block_length=block_length,
                       temperature=0.7,
                       cfg_scale=0.,
                       remasking='low_confidence')
        end_time = time.time()

        # Decode and print ALL batch outputs for verification.
        gen_tail = out[:, prompt.shape[1]:]  # continuation of every row past the prompt
        answers = tokenizer.batch_decode(gen_tail.tolist(),  # plain Python lists for decoding
                                         skip_special_tokens=True)
        print("\n=== Batched outputs ===")
        for i, ans in enumerate(answers):
            print(f"[{i}] {ans}")
        print("=== end ===\n")
        print(f"Generation time for batch of {batch_n}: {end_time - start_time:.2f} seconds")

        # Continue the conversation from the first sample only, with its
        # <EOS> tokens (id 126081) removed.
        prompt = out[0][out[0] != 126081].unsqueeze(0)
        conversation_num += 1
        print('-----------------------------------------------------------------------')

def message_chat():
    """Generate answers for two fixed word problems and print them.

    Each question is wrapped as a single-turn user message, tokenized through
    the chat template, and the whole list is passed to ``generate`` in one
    call. Only the tokens after each prompt are decoded (special tokens are
    kept) and printed per case. Expected answers: 72 and 3.
    """
    questions = [
        "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?",
        "Gail has two fish tanks. The first tank is twice the size of the second tank. There are 48 gallons of water in the first tank. She follows the rule of one gallon of water per inch of fish. If she keeps two-inch fish in the second tank and three-inch fish in the first tank, how many more fish would Gail have in the first tank than the second tank if one of the first tank fish eats another?",
    ]
    messages = [{"role": "user", "content": question} for question in questions]

    device = 'cuda'

    # Load model with optimized settings
    model = AutoModel.from_pretrained(
        'GSAI-ML/LLaDA-8B-Instruct',
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    ).to(device).eval()

    tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)

    # Generation settings
    gen_length = 256
    parallel_tokens = 1
    block_length = 32

    banner_rule = '*' * 66
    print(banner_rule)
    print(f'**  Answer Length: {gen_length}     |  Sampling Steps: {gen_length// parallel_tokens}  **')
    print(banner_rule)

    # Tokenize every message as its own single-turn conversation.
    input_ids_list = []
    for message in messages:
        encoded = tokenizer.apply_chat_template(
            [message],
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
        )
        input_ids_list.append(encoded[0].to(model.device))

    out = generate(
        model,
        tokenizer,
        input_ids_list,
        parallel_tokens=parallel_tokens,
        max_new_tokens=gen_length,
        block_length=block_length,
        temperature=0.7,
        remasking='low_confidence',
    )

    # Release cached, currently unused CUDA blocks.
    torch.cuda.empty_cache()

    # Keep only the tokens that follow each prompt, then decode them.
    generations = [
        tokenizer.decode(output_ids[len(prompt_ids):], skip_special_tokens=False)
        for prompt_ids, output_ids in zip(input_ids_list, out)
    ]

    case_rule = "-" * 80
    for case_no, text in enumerate(generations):
        stripped = text.strip()
        print("\n" + case_rule)
        print(f"[Case {case_no}]")
        print(case_rule)
        print(stripped if stripped else "<empty>")



def infill_chat():
    """Demo of mask infilling on two hand-written conversations.

    Registers ``<|mdm_mask|>`` as the tokenizer's mask token, installs a
    custom chat template, builds two conversations that each contain a run of
    20 mask tokens, and passes their token ids to ``infilling``. The masked
    inputs and the decoded results are printed per case.
    """
    device = 'cuda'
    checkpoint = 'GSAI-ML/LLaDA-8B-Instruct'

    model = AutoModel.from_pretrained(
        checkpoint,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    ).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

    tokenizer.mask_token = "<|mdm_mask|>"
    tokenizer.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
    # The template only appends an assistant header when the last message is
    # not already an assistant turn, so a partially written assistant message
    # stays open for infilling.
    template = """
        {% set loop_messages = messages -%}
        {%- for message in loop_messages %}
        {%- if loop.index0 == 0 -%}{{ bos_token }}{%- endif -%}
        <|start_header_id|>{{ message['role'] }}<|end_header_id|>

        {{ message['content'] | trim }}<|eot_id|>
        {%- endfor -%}
        {%- if add_generation_prompt and (loop_messages | length == 0 or loop_messages[-1]['role'] != 'assistant') %}
        <|start_header_id|>assistant<|end_header_id|>

        {% endif %}
        """.lstrip()
    tokenizer.chat_template = template

    parallel_tokens = 4
    block_length = 32

    blank = tokenizer.mask_token * 20  # a run of 20 mask tokens to be filled in

    # Case 0: recover the question; case 1: complete a partially written answer.
    masked_inputs = [
        [
            {"role": "user", "content": blank},
            {"role": "assistant", "content": "Sorry, I do not have answer to this question."},
        ],
        [
            {"role": "user", "content": "Please write an educational python function."},
            {"role": "assistant", "content": "def hello_" + blank + " return"},
        ],
    ]

    fib_input_ids_list = []
    for conversation in masked_inputs:
        encoded = tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
        )
        fib_input_ids_list.append(encoded[0].to(model.device))

    out = infilling(
        model,
        tokenizer,
        fib_input_ids_list,
        temperature=0.7,
        block_length=block_length,
        parallel_tokens=parallel_tokens,
        remasking='low_confidence',
    )

    filled = tokenizer.batch_decode(out)

    print(filled)

    rule = "-" * 80
    for case_no, (masked_ids, text) in enumerate(zip(fib_input_ids_list, filled)):
        stripped = text.strip()
        print("\n" + rule)
        print(f"[Case {case_no}]")
        print(rule)
        print("[Masked]:\n" + tokenizer.decode(masked_ids))
        print("\n[Filled]:\n" + (stripped if stripped else "<empty>"))

    print("\n" + "=" * 80 + "\n")
    
def remasking_chat():
    """Time ``generate_with_remasking`` on a batch of 64 chat prompts.

    Two word problems are repeated 32 times each, formatted with the chat
    template, left-padded into one batch and generated in a single call. Only
    the elapsed wall-clock time is printed; the decoded continuations are
    computed but not displayed.
    """
    device = 'cuda'
    checkpoint = 'GSAI-ML/LLaDA-8B-Instruct'

    model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

    # The LLaDA architecture theoretically supports both left-padding and right-padding.
    # However, the sampling code implementation is simpler with left-padding.
    if tokenizer.padding_side != 'left':
        tokenizer.padding_side = 'left'

    # If the padding ID equals the mask ID, the generate function would have
    # to be modified to achieve correct inference.
    assert tokenizer.pad_token_id != 126336

    questions = [
        "Hasan is packing up his apartment because he’s moving across the country for a new job. He needs to ship several boxes to his new home. The movers have asked that Hasan avoid putting more than a certain weight in pounds in any cardboard box. The moving company has helpfully provided Hasan with a digital scale that will alert him if a package is too heavy. Hasan is in the kitchen, and he fills a cardboard box with 38 dinner plates. When he checks the box, the scale reports his box is too heavy. Hasan knows each of his plates weighs 10 ounces. He removes a single plate from the box and checks the movers’ scale again. The scale reports his box is still too heavy. Hasan repeats the process again and again. When he has removed enough plates, the movers’ scale shows the box is now an acceptable weight for shipping. Hasan deduces that each shipping box can hold 20 pounds before the scale says the box is too heavy. How many plates did Hasan need to remove from the shipping box?",
        "Joy can read 8 pages of a book in 20 minutes. How many hours will it take her to read 120 pages?",
    ]
    # Repeat the pair 32 times -> a batch of 64 prompts.
    questions = questions * 32

    # Add special tokens for the Instruct model (not needed for the Base model).
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": question}],
            add_generation_prompt=True,
            tokenize=False,
        )
        for question in questions
    ]

    # Single-prompt tokenization kept as a debugging probe; its result is not used below.
    single = tokenizer(prompts[0], return_tensors='pt', padding=True, add_special_tokens=False)
    input_ids_test = single['input_ids'].to(device)

    batch = tokenizer(
        prompts,
        add_special_tokens=False,
        padding=True,
        return_tensors="pt",
    )
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    started = time.time()

    out = generate_with_remasking(
        model,
        input_ids,
        attention_mask,
        parallel_tokens=2,
        gen_length=256,
        block_length=32,
        temperature=0.,
        cfg_scale=0.,
        tokenizer=tokenizer,
        log_to_file=False,
    )

    elapsed = time.time() - started
    print(f"Generation time: {elapsed:.2f} seconds")

    # Decoded continuations past the prompt; printing is intentionally left
    # disabled so that only the timing is reported.
    output = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)

def get_transfer_index(
    logits: torch.Tensor,
    temperature: float,
    remasking: str,
    mask_index: torch.Tensor,   # (B, L) bool
    x: torch.Tensor,            # (B, L) long
    num_transfer_tokens,        # (B,) or (B,1) long tensor, or None when threshold is used
    threshold: float = None,
):
    """
    Propose a token for every position and choose which masked positions to commit.

    Selection rule:
        * ``threshold`` given: every masked position whose confidence is
          ``>= threshold`` is selected (no top-k).
        * otherwise: per row ``b``, the ``num_transfer_tokens[b]`` masked
          positions with the highest confidence are selected.

    Confidence is the softmax probability of the proposed token
    (``remasking == "low_confidence"``) or a uniform random number
    (``remasking == "random"``).

    Returns:
        x0: (B, L) long — proposed tokens at masked positions, ``x`` elsewhere
        transfer_index: (B, L) bool — which positions to update this step

    Raises:
        NotImplementedError: for any other ``remasking`` value.
        ValueError: if both ``threshold`` and ``num_transfer_tokens`` are None.
    """
    # Proposal: argmax over the (possibly noise-perturbed) logits.
    proposal = torch.argmax(add_gumbel_noise(logits, temperature=temperature), dim=-1)

    # Per-position score used to rank the proposals.
    if remasking == "random":
        score = torch.rand(proposal.shape, device=proposal.device, dtype=torch.float64)
    elif remasking == "low_confidence":
        # float64 softmax for numerical stability
        probs = F.softmax(logits.to(torch.float64), dim=-1)
        score = probs.gather(-1, proposal.unsqueeze(-1)).squeeze(-1)
    else:
        raise NotImplementedError(remasking)

    # Unmasked positions keep their token and get the lowest finite score so
    # they always sort last.
    x0 = torch.where(mask_index, proposal, x)
    floor = torch.tensor(torch.finfo(score.dtype).min, device=score.device, dtype=score.dtype)
    confidence = torch.where(mask_index, score, floor)

    # Threshold mode: purely confidence-based, no per-row budget.
    if threshold is not None:
        return x0, mask_index & (confidence >= threshold)

    if num_transfer_tokens is None:
        raise ValueError("num_transfer_tokens must be a tensor when threshold is None.")

    # Normalise the budget to a non-negative (B,) long tensor on the right device.
    budget = num_transfer_tokens
    if budget.dim() == 2 and budget.size(1) == 1:
        budget = budget.squeeze(1)
    budget = budget.to(dtype=torch.long, device=confidence.device).clamp(min=0)

    # order[b, r] is the original column holding the r-th highest confidence.
    order = torch.sort(confidence, dim=1, descending=True)[1]

    B, L = confidence.shape
    # Invert the permutation: rank[b, c] is the sorted position of column c.
    sorted_slots = torch.arange(L, device=confidence.device).unsqueeze(0).expand(B, L).contiguous()
    rank = torch.zeros(B, L, device=confidence.device, dtype=torch.long).scatter(1, order, sorted_slots)

    # A column is selected when its rank falls inside the row's budget; the
    # final AND guarantees an unmasked position is never selected.
    within_budget = rank < budget.unsqueeze(1).expand(B, L)
    transfer_index = within_budget & mask_index

    return x0, transfer_index

def get_transfer_index_dynamic(logits, temperature, remasking, mask_index, x, num_transfer_tokens, factor=1):
    """
    Propose tokens and pick a confidence-dependent number of masked positions to commit.

    For every row the masked positions are ranked by confidence (descending).
    The top-ranked one is always accepted; the one at rank ``n >= 2`` (1-based)
    is accepted only if its confidence is at least ``1 - factor / (n + 1)``.
    Acceptance stops at the first rank that fails, and the accepted prefix is
    transferred. Rows without any masked position transfer nothing.

    Args:
        logits: model logits; argmax over the last dim yields a (b, l) tensor.
        temperature: forwarded to ``add_gumbel_noise``.
        remasking: ``'low_confidence'`` (softmax probability of the proposed
            token) or ``'random'`` (uniform random confidence).
        mask_index: (b, l) bool, True where the position is still masked.
        x: (b, l) current tokens; kept wherever ``mask_index`` is False.
        num_transfer_tokens: ignored -- the per-row count is derived from
            ``mask_index``; kept for signature compatibility.
        factor: scales how quickly the acceptance threshold approaches 1.

    Returns:
        x0: (b, l) proposed tokens at masked positions, ``x`` elsewhere.
        transfer_index: (b, l) bool, positions to update this step.

    Raises:
        NotImplementedError: for an unknown ``remasking`` value.
    """
    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l
    if remasking == 'low_confidence':
        p = F.softmax(logits.to(torch.float64), dim=-1)
        x0_p = torch.squeeze(
            torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)

    x0 = torch.where(mask_index, x0, x)
    confidence = torch.where(mask_index, x0_p, -np.inf)

    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)

    for j in range(confidence.shape[0]):
        # One host transfer per row instead of one comparison sync per rank.
        ranked = torch.sort(confidence[j][mask_index[j]], dim=-1, descending=True)[0].tolist()
        if not ranked:
            # Nothing left to unmask in this row; the old code raised
            # IndexError here when writing the first threshold.
            continue

        # At least one token is transferred: rank 1 is accepted unconditionally.
        accepted = 1
        # `i` is the 0-based rank, so its threshold is 1 - factor / (i + 2).
        for i, conf in enumerate(ranked[1:], start=1):
            if conf < 1 - factor / (i + 2):
                break
            accepted += 1
        # Counting accepted ranks directly avoids the old ambiguity between
        # "loop finished" and "broke on the last rank", which transferred the
        # lowest-ranked token even when it failed its threshold.

        _, select_index = torch.topk(confidence[j], k=accepted)
        transfer_index[j, select_index] = True

    return x0, transfer_index
    
def main():
    """Run ``generate_with_prefix_cache`` on one sample prompt and print the result.

    Loads the model through ``LLaDAModelLM`` from a local Hugging Face cache
    snapshot, formats a single word problem with the chat template, generates
    128 new tokens with file logging enabled, and prints each decoded
    continuation followed by a separator line.
    """
    device = 'cuda'

    # Local snapshot of GSAI-ML/LLaDA-8B-Instruct (machine-specific path).
    snapshot_dir = '/home/hans/.cache/huggingface/hub/models--GSAI-ML--LLaDA-8B-Instruct/snapshots/9275bf8f5a5687507189baf4657e91c51b2be338'

    model = LLaDAModelLM.from_pretrained(snapshot_dir, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(snapshot_dir, trust_remote_code=True)

    # The LLaDA architecture theoretically supports both left-padding and right-padding.
    # However, the sampling code implementation is simpler with left-padding.
    if tokenizer.padding_side != 'left':
        tokenizer.padding_side = 'left'

    # If the padding ID equals the mask ID, the generate function would have
    # to be modified to achieve correct inference.
    assert tokenizer.pad_token_id != 126336

    questions = [
        "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?",
    ]

    # Add special tokens for the Instruct model (not needed for the Base model).
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": question}],
            add_generation_prompt=True,
            tokenize=False,
        )
        for question in questions
    ]

    batch = tokenizer(
        prompts,
        add_special_tokens=False,
        padding=True,
        return_tensors="pt",
    )
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)

    out = generate_with_prefix_cache(
        model,
        input_ids,
        attention_mask,
        parallel_tokens=2,
        gen_length=128,
        block_length=32,
        temperature=0.,
        cfg_scale=0.,
        tokenizer=tokenizer,
        log_to_file=True,
        threshold=None,
        factor=None,
    )

    # Decode only what follows the (padded) prompt.
    continuations = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)
    separator = '-' * 50
    for text in continuations:
        print(text)
        print(separator)

if __name__ == "__main__":
    # Script entry point: runs the prefix-cache demo in main().
    # To run the remasking timing demo instead, call remasking_chat() here.
    main()
