import torch
import torch.nn.functional as F
from typing import Tuple, Optional


########################
# LLaDA (mask diffusion)
########################

def _add_gumbel_noise(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """
    Gumbel-max style perturbation for categorical sampling.

    Notes from LLaDA (arXiv:2409.02908): low-precision Gumbel Max degrades quality.
    We therefore compute in float64 and return a reweighted distribution.
    """
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def _get_num_transfer_tokens(mask_index: torch.Tensor, steps: int) -> torch.Tensor:
    """
    For LLaDA's linear schedule, precompute how many tokens are transitioned per step.
    mask_index: (B, L) bool tensor marking masked tokens in the current block.
    returns: (B, steps) ints per step per batch item.
    """
    mask_num = mask_index.sum(dim=1, keepdim=True)

    base = mask_num // steps
    remainder = mask_num % steps

    num_transfer_tokens = (
        torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64)
        + base
    )

    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, : remainder[i]] += 1

    return num_transfer_tokens


def _llada_logits(
    model,
    x: torch.Tensor,
    prompt_index: torch.Tensor,
    mask_id: int,
    cfg_scale: float,
) -> torch.Tensor:
    """Run one forward pass, applying classifier-free guidance when ``cfg_scale > 0``."""
    if cfg_scale > 0.0:
        # Unconditional branch: the same sequence with every prompt position masked.
        un_x = x.clone()
        un_x[prompt_index] = mask_id
        logits = model.forward(torch.cat([x, un_x], dim=0)).logits
        cond_logits, un_logits = torch.chunk(logits, 2, dim=0)
        return un_logits + (cfg_scale + 1) * (cond_logits - un_logits)
    return model.forward(x).logits


@torch.no_grad()
def llada_generate(
    model,
    tokenizer,
    input_text: str,
    device: str,
    *,
    steps: int = 128,
    max_new_tokens: int = 128,
    block_length: int = 32,
    temperature: float = 0.0,
    cfg_scale: float = 0.0,
    remasking: str = "low_confidence",
    mask_id: int = 126336,
) -> Tuple[str, torch.Tensor]:
    """
    Generate a continuation with LLaDA's original masked-diffusion sampler.

    The continuation starts as ``max_new_tokens`` copies of ``mask_id`` appended to
    the prompt. It is filled in block by block (``block_length`` tokens each, left
    to right). Each block gets ``steps // num_blocks`` denoising steps. At every
    step the model predicts all positions, and a scheduled number of still-masked
    positions inside the current block are committed: those with the highest
    score, so low-confidence predictions stay masked.

    Args:
        model: model whose ``forward(x)`` returns an object with ``.logits``.
        tokenizer: tokenizer used to encode ``input_text`` and decode the result.
        input_text: already chat-formatted prompt string.
        device: device on which the sequence tensors are created.
        steps: total number of denoising steps across all blocks.
        max_new_tokens: length of the generated continuation (``gen_length``).
        block_length: size of each block; must divide ``max_new_tokens``.
        temperature: Gumbel-noise temperature; 0 means greedy argmax.
        cfg_scale: classifier-free guidance scale; 0 disables guidance.
        remasking: score used to choose positions to commit:
            ``"low_confidence"`` (model probability of the predicted token)
            or ``"random"``.
        mask_id: token id used as the mask placeholder.

    Returns:
        The decoded continuation (prompt excluded, special tokens skipped) and the
        prompt ``input_ids``.

    Raises:
        NotImplementedError: unknown ``remasking`` strategy.
        ValueError: non-positive sizes, ``max_new_tokens`` not divisible by
            ``block_length``, or ``steps`` not divisible by the number of blocks.
    """
    # Validate everything before tokenizing or running the model. Explicit raises
    # (not asserts) so the checks survive `python -O`; otherwise an indivisible
    # length would silently leave trailing mask tokens in the output.
    if remasking not in ("low_confidence", "random"):
        raise NotImplementedError(remasking)
    gen_length = max_new_tokens
    if gen_length <= 0:
        raise ValueError("max_new_tokens must be positive")
    if block_length <= 0:
        raise ValueError("block_length must be positive")
    if gen_length % block_length != 0:
        raise ValueError("gen_length must be divisible by block_length")
    num_blocks = gen_length // block_length
    if steps <= 0:
        raise ValueError("steps must be positive")
    if steps % num_blocks != 0:
        raise ValueError("steps must be divisible by num_blocks")
    steps_per_block = steps // num_blocks

    # Tokenize the already-formatted chat string.
    inputs = tokenizer([input_text], return_tensors="pt")
    input_ids = inputs.input_ids.to(device)
    prompt_len = input_ids.shape[1]

    # Prompt followed by gen_length mask placeholders.
    x = torch.full((1, prompt_len + gen_length), mask_id, dtype=torch.long, device=device)
    x[:, :prompt_len] = input_ids
    prompt_index = x != mask_id

    for num_block in range(num_blocks):
        block_start = prompt_len + num_block * block_length
        block_end = block_start + block_length
        block_mask_index = x[:, block_start:block_end] == mask_id
        num_transfer_tokens = _get_num_transfer_tokens(block_mask_index, steps_per_block)

        for i in range(steps_per_block):
            mask_index = x == mask_id
            logits = _llada_logits(model, x, prompt_index, mask_id, cfg_scale)

            logits_with_noise = _add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # (B, L)

            if remasking == "low_confidence":
                p = torch.softmax(logits, dim=-1)
                x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # (B, L)
            else:  # "random" -- the only other value allowed by the validation above
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)

            # Prevent transitions beyond the current block (future blocks stay masked).
            x0_p[:, block_end:] = -float("inf")

            # Only masked positions may change; unmasked ones keep their token.
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -float("inf"))

            transfer_index = torch.zeros_like(x0, dtype=torch.bool)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=int(num_transfer_tokens[j, i]))
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    # Decode only the continuation part.
    text = tokenizer.batch_decode(x[:, prompt_len:], skip_special_tokens=True)[0]
    return text, input_ids


############
# DREAM DLLM
############

@torch.no_grad()
def dream_generate(
    model,
    tokenizer,
    input_text: str,
    device: str,
    *,
    steps: int = 512,
    max_new_tokens: int = 512,
    temperature: float = 0.2,
    top_p: float = 0.95,
    alg: str = "entropy",
    alg_temp: float = 0.0,
) -> Tuple[str, torch.Tensor]:
    """
    Generate a continuation with Dream's ``model.diffusion_generate``.

    Forwarded to the backend: ``max_new_tokens``, ``steps``, ``alg``, ``alg_temp``,
    plus ``output_history=False`` and ``return_dict_in_generate=True``.

    NOTE(review): ``temperature`` and ``top_p`` are accepted but are NOT forwarded,
    so the defaults in this signature have no effect. Whatever sampling settings
    the model's own generation config provides are used instead. Forwarding them
    would change generation results; confirm which behavior is intended.

    Args:
        model: model exposing ``diffusion_generate(input_ids, attention_mask=..., **kw)``
            that returns an object with ``.sequences``.
        tokenizer: tokenizer used to encode ``input_text`` and decode the result.
        input_text: already chat-formatted prompt string.
        device: device the prompt tensors are moved to.
        steps, max_new_tokens, alg, alg_temp: passed through to ``diffusion_generate``.
        temperature, top_p: currently unused (see note above).

    Returns:
        The continuation, truncated before the first EOS token and decoded with
        special tokens skipped, and the prompt ``input_ids``.
    """
    encoded = tokenizer([input_text], return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = getattr(encoded, "attention_mask", None)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    # Only parameters that do not trigger warnings are passed on.
    valid_params = {
        'max_new_tokens': max_new_tokens,
        'output_history': False,
        'return_dict_in_generate': True,
        'steps': steps,
        'alg': alg,
        'alg_temp': alg_temp,
    }

    output = model.diffusion_generate(
        input_ids,
        attention_mask=attention_mask,
        **valid_params,
    )

    sequences = output.sequences  # (B, L_total); B == 1 for a single prompt
    # Strip the prompt to keep only the generated portion.
    gen_ids = sequences[0, input_ids.shape[1]:].tolist()

    # Truncate at the first EOS at the token level, BEFORE decoding. With
    # skip_special_tokens=True the EOS string is normally removed by decode, so a
    # text-level split alone would not fire, and anything generated after EOS
    # would leak into the result.
    eos_id = getattr(tokenizer, "eos_token_id", None)
    if eos_id is not None and eos_id in gen_ids:
        gen_ids = gen_ids[: gen_ids.index(eos_id)]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True)

    # Fallback for tokenizers whose EOS string survives decoding.
    eos_token = getattr(tokenizer, "eos_token", None)
    if eos_token and eos_token in text:
        text = text.split(eos_token)[0]
    return text, input_ids


########################
# Public entry function
########################

def dllm_original_generation(
    *,
    dllm_type: str,
    model,
    tokenizer,
    device: str,
    input_text: str,
    **gen_kwargs,
) -> Tuple[str, torch.Tensor]:
    """
    Dispatch to the original (non-latent-seek) generation routine of a diffusion LM.

    Args:
        dllm_type: backend name, matched case-insensitively: "llada" or "dream".
        model: loaded Hugging Face model for that backend.
        tokenizer: matching tokenizer.
        device: device string the prompt tensors are moved to.
        input_text: pre-formatted chat string (as produced by data.get_dataset).
        **gen_kwargs: backend-specific generation keyword arguments, forwarded as-is.

    Returns:
        (generated_text, input_ids) exactly as returned by the selected backend.

    Raises:
        ValueError: if ``dllm_type`` names neither supported backend.
    """
    backend = dllm_type.lower()
    if backend not in ("llada", "dream"):
        raise ValueError(f"Unsupported dllm_type: {backend}")
    generate = llada_generate if backend == "llada" else dream_generate
    return generate(model, tokenizer, input_text, device, **gen_kwargs)

