import torch
import torch.nn.functional as F
import torch.distributed as dist


def add_gumbel_noise(logits, temperature):
    """Perturb ``logits`` so that ``argmax`` of the result draws a Gumbel-max sample.

    For ``temperature > 0`` the result is ``exp(logits) / (-log U) ** temperature``
    with ``U ~ Uniform[0, 1)`` drawn elementwise.  Taking ``argmax`` over it is
    equivalent to ``argmax(logits / temperature + Gumbel noise)``, i.e. sampling
    from ``softmax(logits / temperature)``.  With ``temperature == 0`` the logits
    are returned unchanged (the very same object), so ``argmax`` is greedy.

    The computation runs in float64: low-precision Gumbel-max sampling is reported
    to degrade generation quality for masked diffusion models (arXiv:2409.02908),
    and float32 ``exp`` overflows to ``inf`` for logits above ~88.7, which would
    make ``argmax`` pick arbitrarily among the overflowed entries.

    Args:
        logits: Tensor of unnormalised scores (any shape).
        temperature: Sampling temperature; ``0.0`` disables the noise.

    Returns:
        ``logits`` itself when ``temperature == 0``, otherwise a float64 tensor
        with the same shape as ``logits``.
    """
    if temperature == 0.0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


@torch.no_grad()
def generate(
    model,
    prompt,
    tokenizer,
    initial_gen_length=64,
    max_gen_length=2048,
    block_length=32,
    temperature=0.0,
    cfg_scale=0.0,
    high_conf_threshold=0.90,
    low_conf_threshold=0.10,
    expansion_factor=8,
    mask_id=126336,
    eos_token_id=126081,
    eos_confidence_threshold=0.5,
    expand_eos_confidence_threshold=0.9,
    eos_check_tokens=32,
):
    """Generate completions by block-wise iterative unmasking with dynamic length.

    Stage 1 appends ``mask_id`` positions to every row (``expansion_factor`` at a
    time, capped at ``max_gen_length``) until the row's EOS score (see
    ``_calculate_eos_confidence``) reaches ``eos_confidence_threshold``.

    Stage 2 walks each row's generation region in windows of ``block_length``
    positions.  Every step it commits masked positions in the active window whose
    predicted-token probability exceeds ``high_conf_threshold`` (and at least one
    position per active window, so the loop always makes progress).  While a row
    is below ``max_gen_length`` and its EOS score is below
    ``expand_eos_confidence_threshold``, the least-confident masked position in its
    window that is below ``low_conf_threshold`` is replaced by ``expansion_factor``
    fresh masks.

    Args:
        model: Callable as ``model(x, attention_mask=...)``; its result's
            ``.logits`` is indexed as ``[batch, position, vocab]``.
        prompt: Token-id tensor of shape ``(batch, prompt_length)``.
        tokenizer: Not used by this function; kept for interface compatibility.
        initial_gen_length: Number of masked positions each row starts with.
        max_gen_length: Upper bound on a row's generation length.
        block_length: Width of the window that is unmasked at a time.
        temperature: Passed to ``add_gumbel_noise`` when choosing tokens;
            ``0.0`` picks the argmax.
        cfg_scale: When > 0, a second pass with the prompt masked is run and the
            logits are combined as ``un + (cfg_scale + 1) * (cond - un)``.
        high_conf_threshold: Probability above which a masked position is committed.
        low_conf_threshold: Probability below which a masked position may be expanded.
        expansion_factor: Number of masks inserted per expansion.
        mask_id: Token id of the mask token.
        eos_token_id: Token id of EOS; must not be None.
        eos_confidence_threshold: Stage-1 EOS score at which a row stops growing.
        expand_eos_confidence_threshold: Stage-2 EOS score at or above which a row
            is no longer expanded.
        eos_check_tokens: Number of right-most EOS predictions averaged into the
            EOS score (also the divisor).

    Returns:
        A list with one 1-D tensor per batch row: the prompt followed by that
        row's generation region (rows may differ in length).
    """

    def _log(message):
        # Print once per job when torch.distributed is active (rank 0 only).
        if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
            print(message)

    def _calculate_eos_confidence(logits, total_lengths, prompt_length, eos_check_tokens):
        """Return a per-row EOS score.

        Within ``[prompt_length, total_lengths[i])``, positions whose argmax is
        EOS are collected from right to left until ``eos_check_tokens`` have been
        found; their EOS probabilities are summed and divided by
        ``eos_check_tokens`` (so missing EOS predictions count as zero).
        """
        if eos_token_id is None:
            return torch.zeros(logits.shape[0], device=logits.device)
        predicted_tokens = torch.argmax(logits, dim=-1)
        # softmax(logits)[..., eos] computed without building the full probability tensor.
        log_norm = torch.logsumexp(logits.float(), dim=-1)
        eos_probs = torch.exp(logits[..., eos_token_id].float() - log_norm)
        positions = torch.arange(logits.shape[1], device=logits.device).unsqueeze(0)
        in_generation = (positions >= prompt_length) & (positions < total_lengths.unsqueeze(1))
        is_eos = (predicted_tokens == eos_token_id) & in_generation
        # Count EOS predictions at or to the right of each position; keeping
        # rank <= eos_check_tokens mirrors a right-to-left scan that stops early.
        rank_from_right = is_eos.flip(-1).long().cumsum(-1).flip(-1)
        selected = is_eos & (rank_from_right <= eos_check_tokens)
        return (eos_probs * selected).sum(dim=-1) / eos_check_tokens

    with torch.autocast(device_type="cuda"):
        batch_size = prompt.shape[0]
        device = prompt.device
        prompt_length = prompt.shape[1]
        assert eos_token_id is not None
        gen_lengths = torch.full((batch_size,), initial_gen_length, dtype=torch.long, device=device)
        x = torch.full(
            (batch_size, prompt_length + initial_gen_length),
            mask_id,
            dtype=torch.long,
            device=device,
        )
        x[:, :prompt_length] = prompt

        # ---- Stage 1: grow each row's masked region until EOS is expected. ----
        _log("[Stage-1] Initial Length Adjustment")
        while True:
            total_lengths = prompt_length + gen_lengths
            positions = torch.arange(x.shape[1], device=device).expand(batch_size, -1)
            # Padding past a row's own length is excluded from attention.
            attention_mask_pre = (positions < total_lengths.unsqueeze(1)).long()
            logits_pre = model(x, attention_mask=attention_mask_pre).logits
            batch_eos_confidences = _calculate_eos_confidence(
                logits_pre, total_lengths, prompt_length, eos_check_tokens
            )
            sequences_to_expand = (batch_eos_confidences < eos_confidence_threshold) & (
                gen_lengths < max_gen_length
            )
            if not sequences_to_expand.any():
                _log(f"All sequences' EOS confidence reach the threshold {eos_confidence_threshold} or max length.")
                break
            _log(
                f"Some sequences' EOS confidence ({[round(c.item(), 4) for c in batch_eos_confidences]}) "
                f"< {eos_confidence_threshold}. Expand initial length."
            )
            new_gen_lengths = gen_lengths.clone()
            new_gen_lengths[sequences_to_expand] = torch.clamp(
                gen_lengths[sequences_to_expand] + expansion_factor, max=max_gen_length
            )
            # Check growth per row: a row shorter than the longest one can grow
            # even though the batch maximum stays the same.
            if not (new_gen_lengths > gen_lengths).any():
                _log(f"WARNING: Cannot expand initial length further (already at max length: {max_gen_length}).")
                break
            new_x = torch.full(
                (batch_size, prompt_length + int(new_gen_lengths.max().item())),
                eos_token_id,
                dtype=torch.long,
                device=device,
            )
            for i in range(batch_size):
                old_total = prompt_length + gen_lengths[i].item()
                new_total = prompt_length + new_gen_lengths[i].item()
                new_x[i, :old_total] = x[i, :old_total]
                # Newly appended positions start masked; empty slice for rows
                # that were not expanded.
                new_x[i, old_total:new_total] = mask_id
            x = new_x
            gen_lengths = new_gen_lengths

        # ---- Stage 2: block-wise denoising with mask insertion. ----
        _log("[Stage-2] Iterative Denoising and Mask Insertion")

        current_pos = torch.full((batch_size,), prompt_length, dtype=torch.long, device=device)
        denoise_only_mode = torch.zeros(batch_size, dtype=torch.bool, device=device)

        while (current_pos < prompt_length + gen_lengths).any():
            total_lengths = prompt_length + gen_lengths
            x_before_step = x.clone()

            for i in range(batch_size):
                if gen_lengths[i] >= max_gen_length and not denoise_only_mode[i]:
                    if current_pos[i] < total_lengths[i]:
                        _log(f"Sequence {i} has reached the max length {max_gen_length}. Entering denoise-only mode.")
                        denoise_only_mode[i] = True

            positions = torch.arange(x.shape[1], device=device).expand(batch_size, -1)
            attention_mask = (positions < total_lengths.unsqueeze(1)).long()
            if cfg_scale > 0.0:
                # Unconditional branch: hide the prompt.  Built from the current x
                # so it stays valid after rows have been resized.
                un_x = x.clone()
                un_x[:, :prompt_length] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                attention_mask_ = torch.cat([attention_mask, attention_mask], dim=0)
                logits, un_logits = torch.chunk(model(x_, attention_mask=attention_mask_).logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x, attention_mask=attention_mask).logits

            predicted_tokens = torch.argmax(add_gumbel_noise(logits, temperature), dim=-1)
            confidences = F.softmax(logits, dim=-1)
            predicted_confidences = torch.gather(
                confidences, dim=-1, index=predicted_tokens.unsqueeze(-1)
            ).squeeze(-1)
            batch_eos_confidences = _calculate_eos_confidence(logits, total_lengths, prompt_length, eos_check_tokens)

            # Active window per row: (start, end), or None once the row is finished.
            block_bounds = []
            block_mask = torch.zeros_like(x, dtype=torch.bool, device=device)
            for i in range(batch_size):
                start = current_pos[i].item()
                total = total_lengths[i].item()
                if start >= total:
                    block_bounds.append(None)
                    continue
                end = min(start + block_length, total)
                block_bounds.append((start, end))
                block_mask[i, start:end] = True

            currently_masked = x == mask_id
            high_conf_indices = (
                (predicted_confidences > high_conf_threshold)
                & block_mask
                & currently_masked
                & (predicted_tokens != mask_id)
            )

            # Guarantee progress: each active window commits at least one token per step.
            for i, bounds in enumerate(block_bounds):
                if bounds is None:
                    continue
                start_idx, end_idx = bounds
                if high_conf_indices[i, start_idx:end_idx].any():
                    continue
                valid_fallback_mask = block_mask[i] & currently_masked[i]
                if not valid_fallback_mask.any():
                    continue
                candidate_indices = torch.where(valid_fallback_mask)[0]
                candidate_confs = predicted_confidences[i, candidate_indices]
                candidate_tokens = predicted_tokens[i, candidate_indices]
                _, sort_indices = torch.sort(candidate_confs, descending=True)
                usable = candidate_tokens[sort_indices] != mask_id
                if usable.any():
                    # Most confident candidate whose prediction is not the mask token.
                    first = torch.nonzero(usable)[0].item()
                    high_conf_indices[i, candidate_indices[sort_indices[first]]] = True
                else:
                    # Every candidate predicts the mask token: re-pick tokens with
                    # the mask token excluded and commit the most confident one.
                    stuck_logits = logits[i, candidate_indices].clone()
                    stuck_logits[:, mask_id] = -torch.inf
                    new_best_confs, new_best_tokens = torch.max(F.softmax(stuck_logits, dim=-1), dim=-1)
                    best_local = torch.argmax(new_best_confs)
                    pos_to_fill = candidate_indices[best_local]
                    predicted_tokens[i, pos_to_fill] = new_best_tokens[best_local]
                    high_conf_indices[i, pos_to_fill] = True

            potential_expand_mask = (
                (predicted_confidences < low_conf_threshold)
                & block_mask
                & currently_masked
                & (~high_conf_indices)
            )
            expand_indices = torch.zeros_like(x, dtype=torch.bool, device=device)
            for i, bounds in enumerate(block_bounds):
                if batch_eos_confidences[i] >= expand_eos_confidence_threshold or gen_lengths[i] >= max_gen_length:
                    continue
                if denoise_only_mode[i] or bounds is None:
                    continue
                masked_candidates = torch.where(potential_expand_mask[i])[0]
                if len(masked_candidates) == 0:
                    continue
                # At most one expansion per row per step: the least confident position.
                _, lowest = torch.topk(predicted_confidences[i, masked_candidates], 1, largest=False)
                expand_indices[i, masked_candidates[lowest]] = True

            x[high_conf_indices] = predicted_tokens[high_conf_indices]

            if expand_indices.any():
                expansion_counts = expand_indices.sum(dim=1)
                grown = torch.clamp(gen_lengths + expansion_counts * (expansion_factor - 1), max=max_gen_length)
                target_gen_lengths = torch.where(expansion_counts > 0, grown, gen_lengths)
                max_new_total_len = prompt_length + int(target_gen_lengths.max().item())

                new_x = torch.full((batch_size, max_new_total_len), eos_token_id, device=device, dtype=torch.long)
                new_gen_lengths = gen_lengths.clone()
                for i in range(batch_size):
                    old_total = prompt_length + gen_lengths[i].item()
                    if expansion_counts[i] == 0:
                        new_x[i, :old_total] = x[i, :old_total]
                        continue
                    # Each expanded position becomes `expansion_factor` masks; the
                    # result is truncated to this row's (clamped) target length.
                    row_cap = target_gen_lengths[i].item()
                    expand_row = expand_indices[i, prompt_length:old_total]
                    segment = x[i, prompt_length:old_total].masked_fill(expand_row, mask_id)
                    repeats = torch.ones_like(segment)
                    repeats[expand_row] = expansion_factor
                    expanded = torch.repeat_interleave(segment, repeats)[:row_cap]
                    new_x[i, :prompt_length] = x[i, :prompt_length]
                    new_x[i, prompt_length:prompt_length + expanded.numel()] = expanded
                    new_gen_lengths[i] = expanded.numel()
                x = new_x
                gen_lengths = new_gen_lengths

            # Advance each row past windows that no longer contain masks.
            for i in range(batch_size):
                total = prompt_length + gen_lengths[i].item()
                while current_pos[i] < total:
                    start = current_pos[i].item()
                    end = min(start + block_length, total)
                    if start == end:
                        break
                    if (x[i, start:end] == mask_id).any():
                        break
                    current_pos[i] = start + block_length

            if torch.equal(x, x_before_step):
                _log("WARNING: Sequence state is stagnant, forcing generation to end.")
                break

        return [x[i, :prompt_length + gen_lengths[i].item()] for i in range(batch_size)]
