import time
import math
import torch
import torch.nn.functional as F
import numpy as np
from accelerate.utils import broadcast


def add_gumbel_noise(logits, temperature, dtype=torch.float64):
    """Perturb logits for Gumbel-max sampling.

    The Gumbel-max trick samples a categorical distribution by taking the
    argmax of noisy scores. This returns ``exp(logits) / (-log(U)) ** temperature``
    with ``U ~ Uniform[0, 1)``; for ``temperature > 0`` its argmax is
    distributed as a sample from ``softmax(logits / temperature)``.

    According to arXiv:2409.02908, for masked diffusion models low-precision
    Gumbel-max improves perplexity but reduces generation quality, so the
    computation defaults to float64. Callers may pass a lower-precision
    ``dtype`` to save memory; note that ``exp`` then overflows sooner
    (e.g. float16 overflows for logits above ~11).

    Args:
        logits: score tensor; noise is drawn independently per element.
        temperature: 0.0 disables sampling and returns ``logits`` unchanged.
        dtype: precision used for the noise and for the returned tensor.

    Returns:
        Tensor with the same shape as ``logits`` in ``dtype``, or ``logits``
        itself (not cast) when ``temperature == 0.0``.
    """
    if temperature == 0.0:
        return logits
    logits = logits.to(dtype)
    noise = torch.rand_like(logits, dtype=dtype)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def get_num_transfer_tokens(mask_index, steps):
    """Spread each row's masked-token count as evenly as possible over ``steps``.

    A row with ``n`` masked positions is assigned ``n // steps`` tokens per
    step, and the first ``n % steps`` steps receive one extra token, so each
    row of the result sums to ``n``.

    Args:
        mask_index: tensor whose truthy entries mark masked positions; counts
            are taken along dim 1 (one count per row).
        steps: number of denoising steps (must be positive).

    Returns:
        int64 tensor of shape (rows, steps) with the per-step unmask counts.
    """
    total = mask_index.sum(dim=1, keepdim=True)
    per_step, leftover = total // steps, total % steps
    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)
    # 1 for the leading `leftover` steps of each row, 0 afterwards.
    gets_extra = (step_ids < leftover).to(per_step.dtype)
    return (per_step + gets_extra).to(torch.int64)


def _find_block_truncations(x, active_indices, sample_starts, block_sizes, block_token_id):
    """Find where each active sample's current block should be cut.

    For each sample ``b`` in ``active_indices`` the window
    ``x[b, start : start + block_sizes[b]]`` is scanned for ``block_token_id``.
    Matches at relative position 0 are ignored.

    Returns:
        dict mapping sample index -> new block length, i.e. the relative index
        of the first qualifying delimiter + 1 (the delimiter itself is kept).
    """
    plan = {}
    for b in active_indices:
        start = sample_starts[b]
        window = x[b, start : start + block_sizes[b]]
        matches = (window == block_token_id).nonzero(as_tuple=True)[0]
        matches = matches[matches >= 1]
        if matches.numel() > 0:
            plan[b] = matches[0].item() + 1
    return plan


def _truncate_block(x, b, new_len, sample_starts, block_sizes, mask_id):
    """Shrink sample ``b``'s current block to ``new_len`` tokens (minimum 1).

    Positions between the new and the old block end are reset to ``mask_id``
    and ``block_sizes[b]`` is updated in place.

    Returns:
        ``(new_end, old_end)`` absolute positions when the block shrank,
        otherwise ``None`` (requested length not shorter than the current one).
    """
    new_len = max(new_len, 1)
    original_len = block_sizes[b]
    if new_len >= original_len:
        return None
    start = sample_starts[b]
    new_end = start + new_len
    old_end = start + original_len
    x[b, new_end:old_end] = mask_id
    block_sizes[b] = new_len
    return new_end, old_end


def _compute_entropy_rewards(block_entropies, device, target_num_blocks=10):
    """Score each sample from its list of per-block mean entropies.

    For samples with at least two blocks the reward is
    ``0.5 * trend_score + 0.5 * quantity_score`` where ``trend_score`` is the
    fraction of consecutive block pairs whose entropy strictly decreased and
    ``quantity_score = log(n + 1) / log(target_num_blocks + 1)`` capped at 1.0.
    Samples with fewer than two blocks get 0.0.
    """
    rewards = torch.zeros(len(block_entropies), device=device)
    for b, entropies in enumerate(block_entropies):
        num_blocks = len(entropies)
        if num_blocks < 2:
            continue
        if num_blocks >= target_num_blocks:
            quantity_score = 1.0
        else:
            quantity_score = math.log(num_blocks + 1) / math.log(target_num_blocks + 1)
        decreases = sum(
            1 for prev, curr in zip(entropies, entropies[1:]) if curr < prev
        )
        trend_score = decreases / (num_blocks - 1)
        rewards[b] = 0.5 * trend_score + 0.5 * quantity_score
    return rewards


@torch.no_grad()
def dynamic_generate(
    model,
    prompt,
    tokenizer,
    steps=128,
    gen_length=256,
    temperature=0.0,
    cfg_scale=0.0,
    remasking="low_confidence",
    mask_id=126336,
):
    """Generate with a masked diffusion LM using dynamically sized blocks.

    Generation runs in rounds. In each round every unfinished sample denoises
    a block of up to ``max(1, gen_length // 2)`` positions starting where its
    previous block ended, for up to ``steps_per_block`` steps, unmasking the
    highest-scoring predicted tokens at each step. If the tokenizer encodes
    ``"block"`` as a single token and that token is generated at relative
    position >= 1 inside the current block, the block is cut right after it:
    later positions are re-masked and the remaining step budget of the round
    is redistributed over the shortened block. A sample stops once a
    committed block contains ``tokenizer.eos_token_id``.

    Args:
        model: callable returning an object with ``.logits`` (last dim is
            the vocabulary) for a batch of token ids; must expose ``.dtype``,
            which sets the Gumbel-noise precision.
        prompt: integer tensor of shape (batch, prompt_len).
        tokenizer: object providing ``encode(text, add_special_tokens=False)``
            and an ``eos_token_id`` attribute (may be None).
        steps: total denoising-step budget, floor-divided across the
            ``gen_length // block_len`` standard blocks (at least 1 each).
        gen_length: number of positions to generate after the prompt.
        temperature: Gumbel-noise temperature; 0.0 means greedy argmax.
        cfg_scale: classifier-free guidance scale; when > 0 an extra forward
            pass is run with the prompt positions replaced by ``mask_id``.
        remasking: ``"low_confidence"`` (commit the most probable predictions
            first) or ``"random"``.
        mask_id: token id that marks masked positions.

    Returns:
        ``(x, entropy_rewards)``. ``x`` has shape
        (batch, prompt_len + gen_length); positions after the block in which a
        sample produced EOS are left as ``mask_id``. ``entropy_rewards`` is a
        float tensor of shape (batch,).

    Raises:
        NotImplementedError: for an unknown ``remasking`` strategy.
    """
    block_token_id = None
    target_str = "block"
    try:
        ids = tokenizer.encode(target_str, add_special_tokens=False)
        # Truncation needs a single-token delimiter; otherwise it is disabled.
        if len(ids) == 1:
            block_token_id = ids[0]
    except Exception as e:
        # Best effort: generation still works without dynamic truncation.
        print(f"Error encoding target string '{target_str}': {e}")

    # Same behaviour as the deprecated torch.cuda.amp.autocast(enabled=True).
    with torch.autocast(device_type="cuda", enabled=True):
        bs = prompt.shape[0]
        dtype = model.dtype
        device = prompt.device
        prompt_len = prompt.shape[1]
        total_seq_len = prompt_len + gen_length

        x = torch.full(
            (bs, total_seq_len), mask_id, dtype=torch.long, device=device
        )
        x[:, :prompt_len] = prompt
        prompt_index = x != mask_id

        # max(1, ...) guards gen_length < 2, which previously made the block
        # length 0 and raised ZeroDivisionError on the next line.
        initial_length = max(1, gen_length // 2)
        num_standard_blocks = max(1, gen_length // initial_length)
        steps_per_block = max(1, steps // num_standard_blocks)

        sample_starts = [prompt_len] * bs
        current_round_block_sizes = [initial_length] * bs
        block_entropies = [[] for _ in range(bs)]
        # Entropy of each position from the most recent step it was active.
        last_step_entropy = torch.zeros(
            (bs, total_seq_len), dtype=torch.float32, device=device
        )

        # Every round advances each active sample by >= 1 position, so
        # gen_length rounds are always enough.
        for _ in range(gen_length):
            active_mask = torch.zeros_like(x, dtype=torch.bool)
            active_indices = []
            for b in range(bs):
                start = sample_starts[b]
                if start >= total_seq_len:
                    continue
                active_indices.append(b)
                end = min(start + initial_length, total_seq_len)
                active_mask[b, start:end] = True
                current_round_block_sizes[b] = end - start

            if not active_indices:
                break

            target_mask_index = (x == mask_id) & active_mask
            num_transfer_tokens = get_num_transfer_tokens(
                target_mask_index, steps_per_block
            )

            # Parallel denoising within the current dynamic block.
            for i in range(steps_per_block):
                if num_transfer_tokens[active_indices, i].sum() == 0:
                    break

                mask_index = x == mask_id

                if cfg_scale > 0.0:
                    un_x = x.clone()
                    un_x[prompt_index] = mask_id
                    logits = model(torch.cat([x, un_x], dim=0)).logits
                    logits, un_logits = torch.chunk(logits, 2, dim=0)
                    logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
                else:
                    logits = model(x).logits

                # One float32 softmax serves both the entropy bookkeeping and
                # the low-confidence score (autocast runs softmax in float32
                # anyway, so a second softmax was redundant).
                probs = F.softmax(logits.float(), dim=-1)
                current_entropy = -torch.sum(probs * torch.log(probs + 1e-9), dim=-1)
                last_step_entropy[active_mask] = current_entropy[active_mask]

                # NOTE: noise precision follows model.dtype here, not float64.
                logits_with_noise = add_gumbel_noise(logits, temperature, dtype=dtype)
                x0 = torch.argmax(logits_with_noise, dim=-1)

                if remasking == "low_confidence":
                    x0_p = torch.gather(probs, dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)
                elif remasking == "random":
                    x0_p = torch.rand(x0.shape, device=device)
                else:
                    raise NotImplementedError(remasking)

                # Only masked positions inside the active block are candidates.
                x0_p[~active_mask] = float("-inf")
                x0 = torch.where(mask_index, x0, x)
                confidence = torch.where(
                    mask_index, x0_p, torch.tensor(float("-inf"), device=device)
                )

                for j in range(bs):
                    num_tokens = num_transfer_tokens[j, i].item()
                    if num_tokens <= 0:
                        continue
                    _, select_indices = torch.topk(confidence[j], k=num_tokens)
                    # topk pads with -inf positions when candidates run out.
                    final_indices = select_indices[
                        confidence[j, select_indices] > float("-inf")
                    ]
                    if final_indices.numel() > 0:
                        x[j, final_indices] = x0[j, final_indices]

                # Dynamic block truncation: detect delimiters first, then apply.
                if block_token_id is not None:
                    plan = _find_block_truncations(
                        x, active_indices, sample_starts,
                        current_round_block_sizes, block_token_id,
                    )
                    for b, new_len in plan.items():
                        span = _truncate_block(
                            x, b, new_len, sample_starts,
                            current_round_block_sizes, mask_id,
                        )
                        if span is None:
                            continue
                        new_end, old_end = span
                        # Stop scoring/unmasking the re-masked tail this round.
                        active_mask[b, new_end:old_end] = False
                        # Re-spread the still-masked tokens of the shortened
                        # block over the steps left in this round.
                        remaining_steps = steps_per_block - (i + 1)
                        if remaining_steps > 0:
                            still_masked = (
                                (x[b, sample_starts[b]:new_end] == mask_id).sum().item()
                            )
                            base, rem = divmod(still_masked, remaining_steps)
                            num_transfer_tokens[b, i + 1 : i + 1 + rem] = base + 1
                            num_transfer_tokens[b, i + 1 + rem :] = base

            # Defensive final pass after the round; expected to be a no-op
            # because the check above runs after every executed step.
            if block_token_id is not None:
                plan = _find_block_truncations(
                    x, active_indices, sample_starts,
                    current_round_block_sizes, block_token_id,
                )
                for b, new_len in plan.items():
                    _truncate_block(
                        x, b, new_len, sample_starts,
                        current_round_block_sizes, mask_id,
                    )

            # Commit the round: record block entropy and advance (or finish).
            eos_id = tokenizer.eos_token_id
            for b in active_indices:
                start = sample_starts[b]
                block_end = start + current_round_block_sizes[b]
                block_entropies[b].append(
                    last_step_entropy[b, start:block_end].mean().item()
                )
                if eos_id is not None and (x[b, start:block_end] == eos_id).any():
                    sample_starts[b] = total_seq_len
                else:
                    sample_starts[b] = block_end

        return x, _compute_entropy_rewards(block_entropies, device)
