# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import (
    GenerationConfig
)
from transformers.utils import (
    ModelOutput,
    is_torchdynamo_compiling,
    logging,
)
import matplotlib.pyplot as plt
import seaborn as sns
import os
import torch 
import numpy as np 


def plot_confidence_heatmap(
    confidence_tensor_item,  # 单个batch item的置信度张量 (gen_length,)
    x_tensor_item_gen_part,  # 单个batch item的x张量的生成部分 (gen_length,)
    mask_id,
    title_info,              # 包含 block, sub-cycle, phase, step 等信息的字典
    save_dir="confidence_plots",
    gen_length=128 # 需要知道gen_length以正确显示x轴
):
    """绘制并保存单个batch item的置信度热力图。"""
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)

    conf_np = confidence_tensor_item.to(torch.float32).cpu().numpy()
    x_np_gen = x_tensor_item_gen_part.cpu().numpy()

    # 将 -np.inf 替换为 NaN 以便绘图，但保留原始值用于注释
    conf_for_heatmap = conf_np.copy()
    conf_for_heatmap[np.isinf(conf_for_heatmap)] = np.nan

    # 创建注释文本：显示置信度，如果不是mask则显示token ID
    annot_text = np.full_like(conf_np, "", dtype=object).reshape(1, -1)
    for i in range(gen_length):
        if x_np_gen[i] != mask_id:
            annot_text[0, i] = f"T:{x_np_gen[i]}" # 已填充的token
        elif not np.isnan(conf_for_heatmap[i]): # 是mask且有置信度
             annot_text[0, i] = f"{conf_np[i]:.2f}"
        # else: # 是mask且置信度为-inf (或NaN)
        #    annot_text[0, i] = "Mask" # 或保持为空


    plt.figure(figsize=(max(15, gen_length // 5), 5))
    sns.heatmap(
        conf_for_heatmap.reshape(1, -1),
        annot=annot_text, # 使用自定义注释
        fmt="s", # 使用字符串格式的注释 (因为annot_text是字符串)
        cmap="viridis",
        cbar=True,
        xticklabels=gen_length // 10 if gen_length > 20 else True,
        yticklabels=False,
        linewidths=.5,
        linecolor='gray',
        # vmin=0, vmax=1 # 如果置信度是概率的话
    )
    
    title = (f"Conf: Blk {title_info['block_idx']}, SubCyc {title_info['sub_cycle']}, "
             f"Ph {title_info['phase']}, Step {title_info['step_in_phase']} "
             f"(TotalCalls {title_info['total_calls']})")
    plt.title(title, fontsize=10)
    plt.xlabel("Token Position in Generation")
    plt.tight_layout()

    filename = (f"conf_blk{title_info['block_idx']}_sc{title_info['sub_cycle']}_"
                f"ph{title_info['phase']}_st{title_info['step_in_phase']}_"
                f"tc{title_info['total_calls']}_endindex{title_info['end_index']}.png")
    filepath = os.path.join(save_dir, filename)
    plt.savefig(filepath, dpi=150)
    plt.close()

logger = logging.get_logger(__name__)


def top_p_logits(logits, top_p=None):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the indices to the right to keep the first token above the threshold
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
    return logits

def top_k_logits(logits, top_k=None):
    top_k = min(top_k, logits.size(-1))  # Safety check
    # Remove all tokens with a probability less than the last token of the top-k
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(indices_to_remove, torch.finfo(logits.dtype).min)
    return logits


def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):

    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except:
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)
    
    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Extract top1 and top2 probabilities
        top1_probs = sorted_probs[:, 0] 
        top2_probs = sorted_probs[:, 1] 
        # Calculate confidence as top1 - top2
        confidence = top1_probs - top2_probs 
    
    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)
    
    return confidence, x0

@torch.no_grad()
def diffusion_generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config = None,
    **kwargs,
):
    # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
    generation_config = self._prepare_generation_config(generation_config, **kwargs)
    generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
    generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)

    # 2. Define model inputs
    assert inputs is not None
    input_ids = inputs
    device = input_ids.device
    attention_mask = kwargs.pop("attention_mask", None)
    self._prepare_special_tokens(generation_config, device=device)

    # 3. Prepare `max_length`.
    input_ids_length = input_ids.shape[-1]
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    generation_config = self._prepare_generated_length(
        generation_config=generation_config,
        has_default_max_length=has_default_max_length,
        input_ids_length=input_ids_length,
    )

    self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
    
    # 4. Check input_ids
    if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
        warnings.warn(
            "You are calling .generate() with the `input_ids` being on a device type different"
            f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
            f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
            " Please make sure that you have put `input_ids` to the"
            f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
            " running `.generate()`.",
            UserWarning,
        )
    if (
        hasattr(generation_config, "pad_token_id") and
        torch.any(input_ids == generation_config.pad_token_id) and 
        attention_mask is None
    ):
        warnings.warn(
            "Padding was detected but no attention mask is passed here. For correct "
            "generation results, please set `attention_mask` when batch-padding inputs.",
            UserWarning,
        )

    input_ids, attention_mask = self._expand_inputs_for_generation(
        expand_size=generation_config.num_return_sequences,
        input_ids=input_ids,
        attention_mask=attention_mask 
    )
    avg_model_calls_length = 0
    global_model_calls = 0
    
    # result = self._sample(
    #     input_ids,
    #     attention_mask=attention_mask,
    #     generation_config=generation_config,
    #     generation_tokens_hook_func=generation_tokens_hook_func,
    #     generation_logits_hook_func=generation_logits_hook_func
    # )
    result, avg_model_calls_length, global_model_calls = slow_fast_sample(
        self,
        input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
        generation_tokens_hook_func=generation_tokens_hook_func,
        generation_logits_hook_func=generation_logits_hook_func
    )
    return result, avg_model_calls_length, global_model_calls


def slow_fast_sample(
    self,
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.LongTensor], # This is the original attention mask
    generation_config,
    generation_tokens_hook_func,
    generation_logits_hook_func
) :
    plot_confidence_maps = False
    plot_save_dir = 'confidence_dream_map'
    total_model_calls_length = 0

    output_history = generation_config.output_history
    return_dict_in_generate = generation_config.return_dict_in_generate
    max_length = generation_config.max_length
    mask_token_id = generation_config.mask_token_id
    
    # max_iterations is the total budget for the entire generation process
    max_total_iterations = generation_config.steps 
    temperature = generation_config.temperature
    top_p = generation_config.top_p
    top_k = generation_config.top_k

    N_CYCLE_DISCOVERY_STEPS = getattr(generation_config, "n_cycle_discovery_steps", 6) 
    CONF_THRESHOLD_CYCLE_LEN_PRED = getattr(generation_config, "conf_threshold_cycle_len_pred", 0.1)
    STABILIZE_COUNT = getattr(generation_config, "stabilize_count", 2)
    CYCLE_LEN_VARIANCE_THRESHOLD = getattr(generation_config, "cycle_len_variance_threshold", 1.0)
    HIGH_CONF_FOR_PARALLEL_DECODE_DISCOVERY = getattr(generation_config, "high_conf_for_parallel_decode_discovery", 0.8)
    MIN_TOKENS_FOR_PARALLEL_DECODE_DISCOVERY = getattr(generation_config, "min_tokens_for_parallel_decode_discovery", 2)

    HIGH_CONF_FOR_CYCLE_DECODE = getattr(generation_config, "high_conf_for_cycle_decode", 0.85)
    MIN_DECODE_IN_CYCLE_FALLBACK = getattr(generation_config, "min_decode_in_cycle_fallback", 2)
    
    histories = [] if (return_dict_in_generate and output_history) else None

    x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)
    batch_size = x.shape[0]
    device = x.device
    prompt_length = input_ids.shape[1]
    gen_length = max_length - input_ids.shape[1]

    if attention_mask is not None and torch.any(attention_mask == 0.0):
        attention_mask_padded = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
        tok_idx = attention_mask_padded.long().cumsum(-1) - 1
        tok_idx.masked_fill_(attention_mask_padded == 0, 1)
        attention_mask_for_model = torch.logical_and(
            attention_mask_padded.unsqueeze(1).unsqueeze(-2),
            attention_mask_padded.unsqueeze(1).unsqueeze(-1),
        )
    else:
        tok_idx = None
        attention_mask_for_model = "full"

    x = generation_tokens_hook_func(None, x, None) # Initial hook

    stable_cycle_length_batch = torch.full((batch_size, max_length), 0, device=device, dtype=torch.long) 
    is_cycle_len_determined_batch = torch.zeros(batch_size, device=device, dtype=torch.bool) # Reset for each major cycle's phase 0
    potential_cycle_lengths_history_tensor = torch.zeros((STABILIZE_COUNT, batch_size), device=device, dtype=torch.long)
    
    global_model_calls = 0 

    # ============================================================
    #  OUTER MAJOR LOOP
    # ============================================================
    major_cycle_iter = -1
    while True:
        major_cycle_iter += 1
        if not (x == mask_token_id).any():
            break
        # Reset determination status for phase 0 of this major cycle
        is_cycle_len_determined_batch.fill_(False)
        potential_cycle_lengths_history_tensor.fill_(0)

        # ============================================================
        #  PHASE 0: CYCLE DISCOVERY (within major_cycle_iter)
        # ============================================================
        for discovery_step in range(N_CYCLE_DISCOVERY_STEPS):
            if not (x == mask_token_id).any(): 
                break
            
            mask_positions_bool_p0 = (x == mask_token_id)
            
            model_outputs_p0 = self(x, attention_mask_for_model, tok_idx)
            logits_p0 = model_outputs_p0.logits
            logits_p0 = torch.cat([logits_p0[:,:1], logits_p0[:, :-1]], dim=1)
            logits_p0 = generation_logits_hook_func(global_model_calls, x, logits_p0)
            
            total_model_calls_length += gen_length

            x_updated_p0 = x.clone()
            confidence = None
            x0 = None

            for b_idx in range(batch_size):
                probs = torch.softmax(logits_p0[b_idx, prompt_length:], dim=-1)
                confidence, x0 = probs.max(dim=-1)

                indices_to_decode_p0 = torch.tensor([], dtype=torch.long, device=device)

                start_abs_idx = prompt_length + stable_cycle_length_batch[b_idx, major_cycle_iter-1]
                start_cycle_idx = stable_cycle_length_batch[b_idx, major_cycle_iter-1]
                
                confidence_undecoded_cycle = torch.where(mask_positions_bool_p0[b_idx, start_abs_idx:], confidence[start_cycle_idx:], torch.tensor(-np.inf, device=x.device, dtype=confidence.dtype))
                if not is_cycle_len_determined_batch[b_idx]:
                    above_thresh_indices_in_scope = (confidence_undecoded_cycle >= CONF_THRESHOLD_CYCLE_LEN_PRED).nonzero(as_tuple=True)[0]

                    if len(above_thresh_indices_in_scope) > 0:
                        farthest_idx_in_scope = above_thresh_indices_in_scope.max().item()
                        potential_len = start_cycle_idx + farthest_idx_in_scope + 1 
                    else:
                        potential_len = start_cycle_idx + 1 
                    current_history_slot = discovery_step % STABILIZE_COUNT
                    potential_cycle_lengths_history_tensor[current_history_slot, b_idx] = potential_len
                    
                    if discovery_step >= 0:
                        history_window_for_item_list = [
                            potential_cycle_lengths_history_tensor[i, b_idx].item() for i in range(STABILIZE_COUNT)
                        ]
                        history_window_tensor = torch.tensor(history_window_for_item_list, 
                                                            device=x.device, dtype=torch.float32)
                        variance_of_lengths = torch.tensor(float('inf'), device=x.device) 
                        if STABILIZE_COUNT >= 2: 
                            variance_of_lengths = torch.var(history_window_tensor, unbiased=True)
                        elif STABILIZE_COUNT == 1: 
                            variance_of_lengths = torch.tensor(0.0, device=x.device)
                        is_variance_low = variance_of_lengths <= CYCLE_LEN_VARIANCE_THRESHOLD

                        if is_variance_low:
                            stable_cycle_length_batch[b_idx, major_cycle_iter] = potential_len
                            is_cycle_len_determined_batch[b_idx] = True
                        else:
                            mean_len = torch.mean(history_window_tensor)
                            stable_cycle_length_batch[b_idx, major_cycle_iter] = torch.round(mean_len).long().clamp(min=1) # 四舍五入并确保至少为1
                            is_cycle_len_determined_batch[b_idx] = False if discovery_step < N_CYCLE_DISCOVERY_STEPS -1 else True# 明确标记为不稳定
                            
                # decode
                num_high_conf = (confidence_undecoded_cycle >= HIGH_CONF_FOR_PARALLEL_DECODE_DISCOVERY).sum().item()
                if num_high_conf >= MIN_TOKENS_FOR_PARALLEL_DECODE_DISCOVERY:
                    indices_to_decode_p0 = torch.where(confidence_undecoded_cycle >= HIGH_CONF_FOR_PARALLEL_DECODE_DISCOVERY)[0]
                elif len(confidence_undecoded_cycle) > 0:
                    _, indices_to_decode_p0 = torch.topk(confidence_undecoded_cycle, k=1)
                
                if len(indices_to_decode_p0) > 0:
                    orig_indices = indices_to_decode_p0 + start_abs_idx
                    tokens_fill = x0[indices_to_decode_p0 + start_cycle_idx]
                    x_updated_p0[b_idx, orig_indices] = tokens_fill
                    
            
            x = x_updated_p0
            x = generation_tokens_hook_func(global_model_calls, x, logits_p0)
            if histories is not None: histories.append(x.clone())
            global_model_calls += 1

            
            if plot_confidence_maps:
                    title_info_p2 = {
                        "block_idx":0, "sub_cycle": major_cycle_iter,
                        "phase": "1", "step_in_phase": N_CYCLE_DISCOVERY_STEPS,
                        "total_calls": global_model_calls,
                        "end_index":-1
                    }
                    confidence_gen_wide = torch.where(mask_positions_bool_p0[b_idx, prompt_length:], confidence, torch.tensor(-np.inf, device=x.device, dtype=confidence.dtype))
                    plot_confidence_heatmap(
                        confidence_gen_wide,
                        x[0, prompt_length:],
                        151666, title_info_p2, plot_save_dir, max_length - prompt_length
                    )

        # ============================================================
        #  PHASE 1: CYCLE DECODE (within major_cycle_iter)
        # ============================================================

        phase1_steps_this_major_cycle = 0
        NO_PROGRESS_PATIENCE_P1 = 3 
        no_progress_counter_p1 = torch.zeros(batch_size, dtype=torch.int, device=device)
        cache_for_out_cycle = None
        is_finish = False

        while True: # Loop for Phase 1 steps within this major_cycle
            start_cycle_idx = stable_cycle_length_batch[b_idx, major_cycle_iter-1]
            end_cycle_idx = stable_cycle_length_batch[b_idx, major_cycle_iter]
            start_abs_idx = stable_cycle_length_batch[b_idx, major_cycle_iter-1] + prompt_length
            end_abs_idx = stable_cycle_length_batch[b_idx, major_cycle_iter] + prompt_length

            mask_positions_bool_p1 = (x == mask_token_id)
            cycle_mask_p1 = mask_positions_bool_p1[b_idx, start_abs_idx:end_abs_idx]

            if not cycle_mask_p1.any(): 
                is_finish = True
                break

            any_progress_this_p1_step = False # Flag to see if any token was decoded in this p1 step across batch
            if phase1_steps_this_major_cycle == 0:
                logits_p1 = self(x, attention_mask_for_model, tok_idx).logits
                cache_for_out_cycle = logits_p1[:, end_abs_idx:]
                total_model_calls_length += gen_length
            else:
                current_tok_idx_p1 = tok_idx[:, :end_abs_idx] if tok_idx is not None else None
                current_attention_mask_p1 = attention_mask_for_model[:, :, :end_abs_idx, :end_abs_idx] \
                                        if attention_mask_for_model != "full" and attention_mask_for_model is not None \
                                        else attention_mask_for_model 
                
                logits_p1 = self(x[:, :end_abs_idx], current_attention_mask_p1, current_tok_idx_p1).logits
                
                logits_p1 = torch.cat([logits_p1, cache_for_out_cycle], dim=1)
                total_model_calls_length += end_cycle_idx
            
            # logits_p1 = model_outputs_p1.logits
            logits_p1 = torch.cat([logits_p1[:,:1], logits_p1[:, :-1]], dim=1)
            logits_p1 = generation_logits_hook_func(global_model_calls, x, logits_p1)

            x_updated_p1 = x.clone()
            confidence_p1 = None
            x0_p1 = None

            for b_idx in range(batch_size):
                if (x[b_idx] == mask_token_id).sum() == 0:
                    no_progress_counter_p1[b_idx] = 0 # Reset patience if done
                    continue
                
                probs_p1 = torch.softmax(logits_p1[b_idx, prompt_length:], dim=-1)
                confidence_p1, x0_p1 = probs_p1.max(dim=-1)

                confidence_undecoded_in_cycle = torch.where(mask_positions_bool_p1[b_idx, start_abs_idx:end_abs_idx], confidence_p1[start_cycle_idx:end_cycle_idx], torch.tensor(-np.inf, device=x.device, dtype=confidence_p1.dtype))
                if start_abs_idx < end_abs_idx:
                    very_high_conf_idx = (confidence_undecoded_in_cycle >= HIGH_CONF_FOR_CYCLE_DECODE).nonzero(as_tuple=True)[0]

                    if len(very_high_conf_idx) < MIN_DECODE_IN_CYCLE_FALLBACK:
                        _, indices_to_decode_p1 = torch.topk(confidence_undecoded_in_cycle, k=1)
                    else:
                        indices_to_decode_p1 = very_high_conf_idx 

                
                if len(indices_to_decode_p1) > 0:
                    orig_indices = indices_to_decode_p1 + start_abs_idx
                    tokens_fill = x0_p1[indices_to_decode_p1 + start_cycle_idx]
                    x_updated_p1[b_idx, orig_indices] = tokens_fill
                    any_progress_this_p1_step = True
                    no_progress_counter_p1[b_idx] = 0 # Progress made for this item
                else:
                    no_progress_counter_p1[b_idx] +=1

            x = x_updated_p1
            
            x = generation_tokens_hook_func(global_model_calls, x, logits_p1)
            if histories is not None: histories.append(x.clone())
            global_model_calls += 1
            phase1_steps_this_major_cycle +=1

        
            if plot_confidence_maps:
                    title_info_p2 = {
                        "block_idx":0, "sub_cycle": major_cycle_iter,
                        "phase": "2&3_HighConf&fix", "step_in_phase": phase1_steps_this_major_cycle,
                        "total_calls": global_model_calls,
                        "end_index":end_cycle_idx
                    }
                    confidence_gen_wide = torch.where(mask_positions_bool_p1[b_idx, prompt_length:], confidence_p1, torch.tensor(-np.inf, device=x.device, dtype=confidence_p1.dtype))
                    plot_confidence_heatmap(
                        confidence_gen_wide,
                        x[0, prompt_length:],
                        151666, title_info_p2, plot_save_dir, max_length - prompt_length
                    )

    
    # --- End of OUTER MAJOR LOOP ---
    avg_model_calls_length = total_model_calls_length / global_model_calls

    final_num_masks_left = (x == mask_token_id).sum()
    if final_num_masks_left > 0:
            if global_model_calls >= max_total_iterations:
                logger.warning(f"达到总模型调用上限 ({max_total_iterations}), 但仍有 {final_num_masks_left.item()} MASK token 未填充.")
            else:
                logger.warning(f"大循环结束, 但仍有 {final_num_masks_left.item()} MASK token 未填充.")

    if return_dict_in_generate:
        return DreamModelOutput(sequences=x, history=histories), avg_model_calls_length, global_model_calls
    else:
        return x, avg_model_calls_length, global_model_calls


@dataclass
class DreamModelOutput(ModelOutput):
    sequences: torch.LongTensor = None
    history: Optional[Tuple[torch.FloatTensor]] = None

