from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.distributions as dists
from torch.nn import functional as F
from transformers import __version__
from transformers.generation.configuration_utils import GenerationConfig
from transformers.utils import ModelOutput, logging


logger = logging.get_logger(__name__)


@dataclass
class DreamModelOutput(ModelOutput):
    """Structured return container for diffusion generation.

    NOTE(review): nothing in the visible part of this file constructs this
    class; the field descriptions below are inferred from names and
    annotations only and should be confirmed against the callers.
    """

    # Annotated as a LongTensor; presumably the generated token ids -- TODO confirm.
    sequences: torch.LongTensor = None
    # Optional tuple of float tensors; presumably intermediate generation
    # states kept when history output is requested -- TODO confirm.
    history: Optional[Tuple[torch.FloatTensor]] = None


class DreamGenerationConfig(GenerationConfig):
    """Generation settings for the diffusion sampler.

    Every recognised option is popped from ``kwargs`` and stored as an
    instance attribute, falling back to the default listed below when the
    option is absent. ``GenerationConfig.__init__`` is not invoked. Unless
    ``_from_model_config`` is truthy, any keyword left over afterwards is
    also stored as an attribute under its own name.
    """

    def __init__(self, **kwargs):
        # (attribute name, default) pairs, applied in this exact order.
        option_defaults = (
            # Sampling controls.
            ("temperature", 0.0),
            ("top_p", None),
            ("top_k", None),
            # Length controls.
            ("max_length", 20),
            ("max_new_tokens", None),
            # Diffusion-specific knobs.
            ("eps", 1e-3),
            ("steps", 512),
            ("alg", 'origin'),
            ("alg_temp", None),
            # Output shaping.
            ("num_return_sequences", 1),
            ("return_dict_in_generate", False),
            ("output_history", False),
            # Special token ids.
            ("mask_token_id", None),
            ("pad_token_id", None),
            ("bos_token_id", None),
            ("eos_token_id", None),
        )
        for option, fallback in option_defaults:
            setattr(self, option, kwargs.pop(option, fallback))

        # Built inline (not via the table) so each instance gets its own
        # fresh dict when the caller supplies none.
        self.generation_kwargs = kwargs.pop("generation_kwargs", {})

        # Bookkeeping fields.
        bookkeeping_defaults = (
            ("_from_model_config", False),
            ("_commit_hash", None),
            ("transformers_version", __version__),
        )
        for option, fallback in bookkeeping_defaults:
            setattr(self, option, kwargs.pop(option, fallback))

        # Leftover keywords are ignored when the config originates from a
        # model config; otherwise each one becomes an attribute.
        leftovers = () if self._from_model_config else kwargs.items()
        for key, value in leftovers:
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error(f"Can't set {key} with value {value} for {self}")
                raise err

        self.validate(is_init=True)

    def validate(self, is_init=False):
        """Intentionally perform no validation and return ``None``."""
        return None


def top_p_logits(logits, top_p=None):
    """Apply nucleus (top-p) filtering along the last dimension of ``logits``.

    Entries are ranked by descending logit; an entry is suppressed when the
    cumulative softmax probability of the entries ranked strictly before it
    already exceeds ``top_p``. The highest-ranked entry is therefore always
    kept. Suppressed entries are set to the most negative finite value of
    the logits dtype.

    Args:
        logits: Floating-point tensor whose last dimension is filtered.
        top_p: Cumulative-probability threshold. ``None`` (the declared
            default) disables filtering and returns ``logits`` unchanged;
            previously it raised a ``TypeError``.

    Returns:
        A tensor shaped like ``logits``. The input is not modified.
    """
    if top_p is None:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift right by one rank so the entry that crosses the threshold is
    # itself kept, and never drop the top-ranked entry.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0

    # Map the per-rank decisions back to the original (unsorted) positions.
    mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    mask = mask.scatter_(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)


def top_k_logits(logits, top_k=None):
    """Keep only the ``top_k`` largest logits along the last dimension.

    Every entry strictly smaller than the k-th largest value is set to the
    most negative finite value of the logits dtype; entries tied with the
    k-th value are kept.

    Args:
        logits: Floating-point tensor whose last dimension is filtered.
        top_k: Number of entries to keep; clamped to the size of the last
            dimension. ``None`` (the declared default) disables filtering
            and returns ``logits`` unchanged; previously it raised a
            ``TypeError``.

    Returns:
        A tensor shaped like ``logits``. The input is not modified.
    """
    if top_k is None:
        return logits

    top_k = min(top_k, logits.size(-1))
    # Smallest value among the k largest, kept with a trailing singleton
    # dimension so it broadcasts against ``logits``.
    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_largest, torch.finfo(logits.dtype).min)


def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Pick one token per position and report a confidence score for it.

    The last dimension of ``logits`` is treated as the vocabulary; all
    leading dimensions are preserved, so 2-D ``(positions, vocab)`` and 3-D
    ``(batch, positions, vocab)`` inputs are both supported.

    Args:
        logits: Floating-point tensor of unnormalised scores.
        temperature: When > 0 the logits are divided by it and tokens are
            sampled from the resulting distribution; otherwise the argmax
            token is chosen.
        top_p: Optional nucleus threshold; applied only when it is < 1.
        top_k: Optional top-k cutoff.
        margin_confidence: If true, the confidence becomes the gap between
            the largest and second-largest probability at each position.
        neg_entropy: If true, the confidence becomes ``sum(p * log p)``
            (negative entropy) at each position. Applied after, and
            overriding, ``margin_confidence``.

    Returns:
        ``(confidence, x0)``, both shaped like ``logits`` without its last
        dimension. By default ``confidence`` is the probability of the
        chosen token.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Best-effort fallback to greedy selection when sampling fails
            # (e.g. invalid probabilities). Deliberately broad, but unlike a
            # bare ``except`` it lets KeyboardInterrupt/SystemExit through.
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        # Index the vocabulary axis with ``...`` so inputs with any number
        # of leading dimensions work; ``[:, k]`` would select along the
        # sequence axis for 3-D inputs.
        top2_probs = torch.topk(probs, 2, dim=-1).values
        confidence = top2_probs[..., 0] - top2_probs[..., 1]

    if neg_entropy:
        epsilon = 1e-10  # keeps log() finite for zero probabilities
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0


@torch.no_grad()
def diffusion_generate_daedal(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[DreamGenerationConfig] = None,
    **kwargs,
) -> Union[DreamModelOutput, torch.LongTensor]:
    """Generate variable-length completions with a two-stage masked-diffusion loop.

    Stage 1 (initial length adjustment): the prompt is followed by a fully
    masked region of ``initial_gen_length`` tokens. The model is run on it,
    an EOS-confidence score is computed near the tail of every sequence, and
    each sequence scoring below ``eos_confidence_threshold`` has its masked
    region grown by ``expansion_factor`` tokens. This repeats until no
    sequence needs expansion or the batch cannot grow any further.

    Stage 2 (iterative denoising and mask insertion): sequences are processed
    in windows of ``block_length`` positions starting at ``current_pos``.
    Each step fills masked positions in the window whose confidence exceeds
    ``high_conf_threshold`` (at least one position is always filled via a
    fallback), and may replace the single least-confident masked position
    below ``low_conf_threshold`` with ``expansion_factor`` mask tokens.

    All algorithm knobs are read from ``generation_config.generation_kwargs``
    (see the ``gen_kwargs.get`` calls below for names and defaults).

    Args:
        self: Object providing ``_prepare_generation_config``,
            ``_prepare_special_tokens`` and a ``__call__`` accepting
            ``attention_mask`` and ``tok_idx`` whose result has ``.logits``
            (presumably a Dream model instance -- TODO confirm).
        inputs: Prompt token ids, indexed as ``(batch, prompt_length)``.
            NOTE(review): the default ``None`` is not supported --
            ``inputs.device`` is accessed unconditionally.
        generation_config: Optional config forwarded to
            ``self._prepare_generation_config`` together with ``kwargs``.

    Returns:
        A Python ``list`` with one 1-D tensor per batch element containing
        the prompt followed by that element's generated tokens.
        NOTE(review): this does not match the declared return annotation.
    """

    def _calculate_eos_confidence(logits, total_lengths, prompt_length, eos_check_tokens, eos_token_id):
        # Per-sequence EOS score: scan the generated region from its last
        # position backwards and collect the EOS probability at positions
        # whose argmax prediction is EOS, stopping once `eos_check_tokens`
        # values have been collected. Positions predicting another token
        # are skipped, not treated as a stop condition.
        if eos_token_id is None:
            return torch.zeros(logits.shape[0], device=logits.device)
        confidences = F.softmax(logits, dim=-1)
        predicted_tokens = torch.argmax(logits, dim=-1)
        batch_eos_confidences = []
        for i in range(logits.shape[0]):
            eos_confs_for_avg = []
            start_scan_pos = total_lengths[i].item() - 1
            # Exclusive lower bound: the scan ends at the first generated position.
            end_scan_pos = prompt_length - 1
            for pos in range(start_scan_pos, end_scan_pos, -1):
                if len(eos_confs_for_avg) >= eos_check_tokens:
                    break
                if predicted_tokens[i, pos] == eos_token_id:
                    eos_confs_for_avg.append(confidences[i, pos, eos_token_id].item())
            # The sum is divided by `eos_check_tokens`, not by the number of
            # collected values, so finding fewer EOS predictions lowers the
            # score. NOTE(review): confirm this is intended; it also raises
            # ZeroDivisionError when `eos_check_tokens` is 0.
            avg_conf = sum(eos_confs_for_avg) / eos_check_tokens
            batch_eos_confidences.append(avg_conf)
        return torch.tensor(batch_eos_confidences, device=logits.device)

    generation_config = self._prepare_generation_config(generation_config, **kwargs)
    gen_kwargs = generation_config.generation_kwargs
    
    # Algorithm hyper-parameters, all taken from `generation_kwargs`.
    initial_gen_length = gen_kwargs.get("initial_gen_length", 64)
    max_gen_length = gen_kwargs.get("max_gen_length", 2048)
    block_length = gen_kwargs.get("block_length", 32)
    # Read from `generation_kwargs`, not from `generation_config.temperature`.
    temperature = gen_kwargs.get("temperature", 0.0)
    # NOTE(review): `cfg_scale` is read here but never used below.
    cfg_scale = gen_kwargs.get("cfg_scale", 0.0)
    high_conf_threshold = gen_kwargs.get("high_conf_threshold", 0.90)
    low_conf_threshold = gen_kwargs.get("low_conf_threshold", 0.10)
    expansion_factor = gen_kwargs.get("expansion_factor", 8)
    eos_confidence_threshold = gen_kwargs.get("eos_confidence_threshold", 0.5)
    expand_eos_confidence_threshold = gen_kwargs.get("expand_eos_confidence_threshold", 0.9)
    eos_check_tokens = gen_kwargs.get("eos_check_tokens", 32)

    self._prepare_special_tokens(generation_config, device=inputs.device)
    mask_token_id = generation_config.mask_token_id
    eos_token_id = generation_config.eos_token_id
    
    prompt = inputs
    batch_size = prompt.shape[0]
    device = prompt.device
    prompt_length = prompt.shape[1]
    
    # Working buffer: prompt followed by `initial_gen_length` mask tokens.
    # `gen_lengths` tracks each sequence's own generated-region length; the
    # buffer width is the batch maximum and the surplus is treated as padding.
    gen_lengths = torch.full((batch_size,), initial_gen_length, dtype=torch.long, device=device)
    x = torch.full(
        (batch_size, prompt_length + initial_gen_length),
        mask_token_id,
        dtype=torch.long,
        device=device,
    )
    x[:, :prompt_length] = prompt.clone()

    # ---- Stage 1: grow the masked region until the EOS score is high enough ----
    while True:
        total_lengths = prompt_length + gen_lengths
        max_len_pre = x.shape[1]
        arange_tensor_pre = torch.arange(max_len_pre, device=device).expand(batch_size, -1)
        # True for positions inside each sequence's own length.
        attention_mask_pre = (arange_tensor_pre < total_lengths.unsqueeze(1))
        
        if torch.any(attention_mask_pre == 0.0):
            # Some rows are shorter than the buffer: build position ids from
            # the mask (padded positions get index 1) and a pairwise mask in
            # which a (query, key) pair is allowed only if both are valid.
            tok_idx_pre = attention_mask_pre.long().cumsum(-1) - 1
            tok_idx_pre.masked_fill_(attention_mask_pre == 0, 1)
            attention_mask_pre_4d = torch.logical_and(
                attention_mask_pre.unsqueeze(1).unsqueeze(-2),
                attention_mask_pre.unsqueeze(1).unsqueeze(-1),
            )
        else:
            # No padding anywhere: pass the "full" sentinel and no position ids.
            tok_idx_pre = None
            attention_mask_pre_4d = "full"
        
        logits_pre = self(x, attention_mask=attention_mask_pre_4d, tok_idx=tok_idx_pre).logits
        # Shift logits right by one: position t receives the logits emitted at
        # position t-1 (position 0 keeps its own).
        logits_pre = torch.cat([logits_pre[:, :1], logits_pre[:, :-1]], dim=1)
        
        batch_eos_confidences = _calculate_eos_confidence(logits_pre, total_lengths, prompt_length, eos_check_tokens, eos_token_id)
        
        # Expand only sequences that are below the threshold and not yet at the cap.
        sequences_to_expand = (batch_eos_confidences < eos_confidence_threshold) & (gen_lengths < max_gen_length)
        
        # Progress messages are printed only on rank 0 (or when torch.distributed
        # is not initialised).
        if not sequences_to_expand.any():
            if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
                print(f"All sequences' EOS confidence reach the threshold {eos_confidence_threshold} or max length.")
            break
        if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
                print(f"Some sequences' EOS confidence ({[round(c.item(), 4) for c in batch_eos_confidences]}) < {eos_confidence_threshold}. Expand initial length.")

        new_gen_lengths = gen_lengths.clone()
        new_gen_lengths[sequences_to_expand] = torch.clamp(gen_lengths[sequences_to_expand] + expansion_factor, max=max_gen_length)
        
        # NOTE(review): only the batch maxima are compared, so this also stops
        # (discarding `new_gen_lengths`) when a shorter sequence would have
        # grown without raising the maximum -- confirm this is intended.
        if new_gen_lengths.max() <= gen_lengths.max():
            if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
                print(f"WARNING: Cannot expand initial length further (already at max length: {max_gen_length}).")
            break

        # `max_new_total_len` is a 0-dim tensor here and is used as a size below.
        max_new_total_len = prompt_length + new_gen_lengths.max()
        new_x_tensor = torch.full((batch_size, max_new_total_len), mask_token_id, dtype=torch.long, device=device)
        
        # Copy each row's existing content; the added tail stays as mask tokens.
        for i in range(batch_size):
            original_total_len = prompt_length + gen_lengths[i].item()
            new_x_tensor[i, :original_total_len] = x[i, :original_total_len]
        
        x = new_x_tensor
        gen_lengths = new_gen_lengths
    
    if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
        print(f"[Stage-2] Iterative Denoising and Mask Insertion")

    # ---- Stage 2: block-wise denoising with optional mask insertion ----
    # `current_pos[i]` is the start of the window currently being denoised.
    current_pos = torch.full((batch_size,), prompt_length, dtype=torch.long, device=device)
    # Set once a sequence reaches `max_gen_length`; such sequences no longer expand.
    denoise_only_mode = torch.zeros(batch_size, dtype=torch.bool, device=device)
    
    while (current_pos < prompt_length + gen_lengths).any():
        # Snapshot used by the stagnation check at the end of the step.
        x_before_step = x.clone()
        
        for i in range(batch_size):
            if gen_lengths[i] >= max_gen_length and not denoise_only_mode[i]:
                if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
                        print(f"Sequence {i} has reached the max length {max_gen_length}. Entering denoise-only mode.")
                denoise_only_mode[i] = True

        # Same padding-aware mask / position-id construction as in Stage 1.
        total_lengths = prompt_length + gen_lengths
        max_len = x.shape[1]
        arange_tensor = torch.arange(max_len, device=device).expand(batch_size, -1)
        attention_mask = (arange_tensor < total_lengths.unsqueeze(1))
        
        if torch.any(attention_mask == 0.0):
            tok_idx = attention_mask.long().cumsum(-1) - 1
            tok_idx.masked_fill_(attention_mask == 0, 1)
            attention_mask_4d = torch.logical_and(
                attention_mask.unsqueeze(1).unsqueeze(-2),
                attention_mask.unsqueeze(1).unsqueeze(-1),
            )
        else:
            tok_idx = None
            attention_mask_4d = "full"
            
        logits = self(x, attention_mask=attention_mask_4d, tok_idx=tok_idx).logits
        # Same one-position right shift as in Stage 1.
        logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

        # Candidate token and confidence for every position of the buffer.
        predicted_confidences, predicted_tokens = sample_tokens(logits, temperature=temperature)
        batch_eos_confidences = _calculate_eos_confidence(logits, total_lengths, prompt_length, eos_check_tokens, eos_token_id)
        
        # Mark the active window [current_pos, current_pos + block_length),
        # clipped to each sequence's length.
        block_mask = torch.zeros_like(x, dtype=torch.bool, device=device)
        for i in range(batch_size):
            if current_pos[i] >= total_lengths[i]: continue
            block_mask[i, current_pos[i]:min(current_pos[i] + block_length, total_lengths[i].item())] = True
        
        currently_masked = (x == mask_token_id)
        # Positions to fill: still masked, inside the window, confident enough,
        # and not predicted to be the mask token itself.
        high_conf_indices = (predicted_confidences > high_conf_threshold) & block_mask & currently_masked & (predicted_tokens != mask_token_id)

        # Fallback: if nothing in a sequence's window qualified, force exactly
        # one masked position to be filled.
        for i in range(batch_size):
            if current_pos[i] >= total_lengths[i]: continue
            start_idx, end_idx = current_pos[i], min(current_pos[i] + block_length, total_lengths[i].item())
            if not high_conf_indices[i, start_idx:end_idx].any():
                valid_fallback_mask = block_mask[i] & currently_masked[i]
                if not valid_fallback_mask.any(): continue
                candidate_indices = torch.where(valid_fallback_mask)[0]
                if len(candidate_indices) == 0: continue
                candidate_confs = predicted_confidences[i, candidate_indices]
                candidate_tokens = predicted_tokens[i, candidate_indices]
                sorted_confs, sort_indices = torch.sort(candidate_confs, descending=True)
                # Most confident candidate whose prediction is not the mask token.
                best_idx_to_fill = -1
                for sorted_idx in sort_indices:
                    if candidate_tokens[sorted_idx] != mask_token_id:
                        best_idx_to_fill = candidate_indices[sorted_idx]; break
                if best_idx_to_fill != -1:
                    high_conf_indices[i, best_idx_to_fill] = True
                else:
                    # Every candidate predicted the mask token: re-score with the
                    # mask token excluded and fill the most confident position.
                    # Indexing with an index tensor yields a copy, so `logits`
                    # itself is not modified by the in-place write below.
                    stuck_logits = logits[i, candidate_indices]
                    stuck_logits[:, mask_token_id] = -torch.inf
                    new_confidences = F.softmax(stuck_logits, dim=-1)
                    new_best_confs, new_best_tokens = torch.max(new_confidences, dim=-1)
                    best_of_the_best_local_idx = torch.argmax(new_best_confs)
                    pos_to_fill = candidate_indices[best_of_the_best_local_idx]
                    token_to_fill = new_best_tokens[best_of_the_best_local_idx]
                    predicted_tokens[i, pos_to_fill] = token_to_fill
                    high_conf_indices[i, pos_to_fill] = True

        # Expansion candidates: low-confidence masked window positions that are
        # not being filled this step.
        potential_expand_mask = (predicted_confidences < low_conf_threshold) & block_mask & currently_masked & (~high_conf_indices)
        expand_indices = torch.zeros_like(x, dtype=torch.bool, device=device)
        for i in range(batch_size):
            # Skip sequences that already look finished (high EOS score), are at
            # the length cap, are in denoise-only mode, or are fully processed.
            if batch_eos_confidences[i] >= expand_eos_confidence_threshold or gen_lengths[i] >= max_gen_length: continue
            if denoise_only_mode[i] or current_pos[i] >= total_lengths[i]: continue
            masked_candidates = torch.where(potential_expand_mask[i])[0]
            if len(masked_candidates) > 0:
                candidate_confs = predicted_confidences[i, masked_candidates]
                # min(1, ...) means at most one position is expanded per
                # sequence per step.
                num_to_expand = min(1, len(masked_candidates))
                if num_to_expand > 0:
                    _, lowest_conf_local_indices = torch.topk(candidate_confs, num_to_expand, largest=False)
                    indices_to_expand_global = masked_candidates[lowest_conf_local_indices]
                    expand_indices[i, indices_to_expand_global] = True
        
        fill_mask = high_conf_indices
        if not expand_indices.any():
            x[fill_mask] = predicted_tokens[fill_mask]
        else:
            x[fill_mask] = predicted_tokens[fill_mask]
            # Each expanded position turns one mask token into `expansion_factor`
            # mask tokens, i.e. a net growth of `expansion_factor - 1`.
            temp_new_gen_lengths = gen_lengths.clone()
            for i in range(batch_size):
                expansion_count = expand_indices[i].sum().item()
                if expansion_count > 0:
                    new_len = gen_lengths[i].item() + expansion_count * (expansion_factor - 1)
                    temp_new_gen_lengths[i] = min(new_len, max_gen_length)
            # 0-dim tensor; used as a size and as an upper bound below.
            max_new_total_len = prompt_length + temp_new_gen_lengths.max()
            
            new_x_tensor = torch.full((batch_size, max_new_total_len), mask_token_id, device=device, dtype=torch.long)
            new_gen_lengths = torch.zeros_like(gen_lengths)

            for i in range(batch_size):
                if not expand_indices[i].any():
                    # Unchanged row: copy as is.
                    total_len = prompt_length + gen_lengths[i].item()
                    new_x_tensor[i, :total_len] = x[i, :total_len]
                    new_gen_lengths[i] = gen_lengths[i]
                    continue
                # Rebuild the row token by token, writing a run of mask tokens
                # wherever an expansion was requested. Tokens that no longer
                # fit within `max_new_total_len` are dropped.
                write_ptr = prompt_length
                new_x_tensor[i, :prompt_length] = x[i, :prompt_length]
                for j in range(prompt_length, prompt_length + gen_lengths[i].item()):
                    if write_ptr >= max_new_total_len: break
                    if expand_indices[i, j]:
                        end_write = min(write_ptr + expansion_factor, max_new_total_len)
                        new_x_tensor[i, write_ptr:end_write] = mask_token_id
                        write_ptr = end_write
                    else:
                        new_x_tensor[i, write_ptr] = x[i, j]
                        write_ptr += 1
                new_gen_lengths[i] = write_ptr - prompt_length
            x = new_x_tensor
            gen_lengths = new_gen_lengths

        # Advance each sequence past every leading window that no longer
        # contains a mask token.
        for i in range(batch_size):
            total_len = prompt_length + gen_lengths[i]
            while current_pos[i] < total_len:
                start_check = current_pos[i]
                end_check = min(start_check + block_length, total_len.item())
                if start_check == end_check: break
                if not (x[i, start_check:end_check] == mask_token_id).any():
                    current_pos[i] = start_check + block_length
                else:
                    break
        # Safety net against an infinite loop: stop when a whole step changed
        # nothing. `torch.equal` is False when the shapes differ, so a step
        # that expanded `x` never triggers this.
        if torch.equal(x, x_before_step):
            if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
                print(f"WARNING: Sequence state is stagnant, forcing generation to end.")
            break
    # Strip the padding: return each row truncated to its own length.
    final_outputs_list = []
    for i in range(batch_size):
        final_len = prompt_length + gen_lengths[i]
        final_outputs_list.append(x[i, :final_len])

    return final_outputs_list