"""
Reimplementation of
https://github.com/ML-GSAI/LLaDA/blob/main/generate.py
"""

import os
from dataclasses import dataclass
from typing import Literal

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer

from bd_mcts.demask_interface.base import DemaskStepper, Remasker

# Remasking strategies accepted by LladaGenerationConfig.alg.
AlgType = Literal["random", "low_confidence"]

# Spellings (compared after trimming whitespace and lower-casing) that turn an
# environment flag on; anything else, including an unset variable, means off.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "y", "on"})


def _env_flag(name: str) -> bool:
    """Return True when environment variable ``name`` holds a truthy spelling."""
    raw_value = os.environ.get(name, "")
    return raw_value.strip().lower() in _TRUTHY_ENV_VALUES


# When set, Llada.predict_x0 leaves the mask token's logit untouched, so the
# model may predict the mask token itself. Off by default.
_ALLOW_MASK_TOKEN = _env_flag("BD_MCTS_DEMASK_ALLOW_MASK_TOKEN")
# When set, each LladaRemasker records how many positions it will transfer.
_DEBUG_DEMASK = _env_flag("BD_MCTS_DEBUG_DEMASK")


@dataclass
class LladaGenerationConfig:
    """Generation settings consumed by :class:`Llada`.

    The four special-token ids default to -1 and must be supplied explicitly.
    ``block_length=-1`` is resolved to ``gen_length``, i.e. a single block.

    Raises:
        ValueError: from ``__post_init__`` if a special-token id is negative,
            ``gen_length``/``block_length``/``steps`` is not positive (after
            resolving ``block_length``), ``block_length`` does not divide
            ``gen_length``, or ``num_blocks`` does not divide ``steps``.
    """

    # Total number of demasking steps, split evenly across the blocks.
    steps: int = 512
    # Number of positions at the end of the sequence that are generated.
    gen_length: int = 512
    # Tokens per block; -1 means "one block covering gen_length".
    block_length: int = -1

    # How the positions to unmask at each step are chosen.
    alg: AlgType = "random"

    # sampling
    # 0 selects the argmax token; > 0 adds Gumbel noise at this temperature.
    temperature: float = 0.0
    # Must stay 0: Llada.predict_x0 raises RuntimeError when it is > 0.
    cfg_scale: float = 0.0
    # NOTE(review): not read by any code visible in this module.
    max_new_tokens: int | None = None

    # Subtracted from the EOS logit before sampling; applied only when > 0.
    eos_penalty: float = 0.0
    # Subtracted from each eot_token_ids logit *after* sampling, so it only
    # lowers those tokens' confidence for "low_confidence" remasking; > 0 only.
    eot_confidence_penalty: float = 0.0

    mask_token_id: int = -1
    pad_token_id: int = -1
    bos_token_id: int = -1
    eos_token_id: int = -1
    # Presumably LLaDA's end-of-turn token ids -- verify against the tokenizer.
    eot_token_ids: tuple[int, ...] = (126081, 126348)

    def __post_init__(self):
        # Explicit raises rather than `assert`, so validation survives `python -O`.
        for field_name in (
            "mask_token_id",
            "pad_token_id",
            "bos_token_id",
            "eos_token_id",
        ):
            value = getattr(self, field_name)
            if value < 0:
                raise ValueError(
                    f"{field_name} must be a non-negative token id, got {value}"
                )

        if self.gen_length <= 0:
            raise ValueError(f"gen_length must be positive, got {self.gen_length}")

        if self.block_length == -1:
            self.block_length = self.gen_length
        if self.block_length <= 0:
            raise ValueError(
                f"block_length must be positive or -1, got {self.block_length}"
            )
        if self.gen_length % self.block_length != 0:
            raise ValueError(
                f"gen_length ({self.gen_length}) must be divisible by "
                f"block_length ({self.block_length})"
            )

        if self.steps <= 0:
            raise ValueError(f"steps must be positive, got {self.steps}")
        if self.steps % self.num_blocks != 0:
            raise ValueError(
                f"steps ({self.steps}) must be divisible by "
                f"num_blocks ({self.num_blocks})"
            )

    @property
    def num_blocks(self) -> int:
        """Number of blocks the generation region is split into."""
        return self.gen_length // self.block_length


def add_gumbel_noise(logits: Tensor, temperature: float) -> Tensor:
    """Perturb ``logits`` so that taking ``argmax`` draws a categorical sample.

    With ``temperature == 0`` the input tensor itself is returned (greedy
    decoding). Otherwise this returns the float64 tensor
    ``exp(logits) / (-log U) ** temperature`` with ``U ~ Uniform[0, 1)``.
    Its argmax equals ``argmax(logits + temperature * G)``, where
    ``G = -log(-log U)`` is standard Gumbel noise.

    Per arXiv:2409.02908, low-precision Gumbel-max in masked diffusion models
    improves perplexity but hurts generation quality, hence float64 throughout.
    """
    if temperature != 0:
        logits64 = logits.to(torch.float64)
        uniform = torch.rand_like(logits64, dtype=torch.float64)
        exponential = -torch.log(uniform)
        return torch.exp(logits64) / exponential.pow(temperature)
    return logits


def get_num_transfer_tokens(mask_index: Tensor, steps: int) -> Tensor:
    """Split each row's masked-token count as evenly as possible over ``steps``.

    LLaDA's linear noise schedule (Eq. (8)) discretises [0, 1] into ``steps``
    equal intervals, so every reverse step should unmask about the same number
    of tokens. Row ``i`` gets ``m_i // steps`` tokens per step, and its first
    ``m_i % steps`` steps receive one extra token, where ``m_i`` is the number
    of True entries in ``mask_index[i]``.

    Returns:
        int64 tensor of shape ``(mask_index.shape[0], steps)`` whose rows sum
        to the per-row mask counts.
    """
    masked_per_row = mask_index.sum(dim=1, keepdim=True)  # (batch, 1)
    per_step = masked_per_row // steps
    leftover = masked_per_row % steps

    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)  # (1, steps)
    # Broadcast to (batch, steps): True for the first `leftover` steps of each row.
    extra = (step_ids < leftover).to(torch.int64)
    return per_step + extra


class Llada(DemaskStepper):
    """Reverse-diffusion demasking stepper for LLaDA-style masked diffusion models.

    The last ``gen_length`` positions of ``x`` are generated in ``num_blocks``
    consecutive blocks of ``block_length`` tokens; each block is given
    ``steps // num_blocks`` consecutive calls to :meth:`predict_x0`.
    """

    def __init__(self, model: PreTrainedModel, **kwargs) -> None:
        """
        Args:
            model: Called as ``model(x, attention_mask=...)``; its output must
                expose ``.logits``.
            **kwargs: Forwarded verbatim to :class:`LladaGenerationConfig`.
        """
        self.model = model
        self.generation_config = LladaGenerationConfig(**kwargs)

    @staticmethod
    def _pad_attention_mask(attention_mask: Tensor, seq_len: int) -> Tensor:
        """Right-pad ``attention_mask`` with ones to ``seq_len`` columns."""
        cur_len = attention_mask.shape[1]
        if cur_len > seq_len:
            raise ValueError(
                "attention_mask length must be <= x length when provided"
            )
        if cur_len == seq_len:
            return attention_mask
        padding = torch.ones(
            (attention_mask.shape[0], seq_len - cur_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        return torch.cat([attention_mask, padding], dim=-1)

    @torch.no_grad()
    def predict_x0(
        self, x: Tensor, step_idx: int, attention_mask: Tensor | None = None
    ) -> tuple[Tensor, Remasker]:
        """Run one demasking step on ``x``.

        Args:
            x: 2-D token-id tensor ``(batch, seq_len)``. Its last ``gen_length``
                columns are the generation region; ``mask_token_id`` marks
                positions still to be filled.
            step_idx: Global step index in ``[0, steps)``.
            attention_mask: Optional mask passed to the model; if it is shorter
                than ``x`` it is right-padded with ones.

        Returns:
            ``(x0, remasker)``. ``x0`` equals ``x`` at unmasked positions and
            holds the predicted token at masked positions. ``remasker.step(x0)``
            copies the selected subset of predictions into ``x`` in place.

        Raises:
            ValueError: ``step_idx`` is out of range, ``x`` is shorter than
                ``gen_length``, ``attention_mask`` is longer than ``x``, or the
                active block runs past the end of ``x``.
            RuntimeError: ``cfg_scale > 0`` (unsupported).
        """
        cfg = self.generation_config
        mask_id = cfg.mask_token_id

        if step_idx < 0 or step_idx >= cfg.steps:
            raise ValueError(
                f"step_idx must be in [0, {cfg.steps}), got {step_idx}"
            )

        steps_per_block = cfg.steps // cfg.num_blocks
        block_idx = step_idx // steps_per_block
        step_in_block = step_idx % steps_per_block

        input_ids_length = x.shape[1] - cfg.gen_length
        if input_ids_length < 0:
            raise ValueError(
                "x length must be >= gen_length to compute input_ids_length"
            )

        if attention_mask is not None:
            attention_mask = self._pad_attention_mask(attention_mask, x.shape[1])

        block_start = input_ids_length + block_idx * cfg.block_length
        block_end = block_start + cfg.block_length
        if block_end > x.shape[1]:
            raise ValueError("block range exceeds x length")

        # Spread the block's *remaining* masks over its *remaining* steps and
        # take the count for the current step. When each earlier step in the
        # block unmasked exactly its quota, this equals entry `step_in_block`
        # of get_num_transfer_tokens(masks_at_block_start, steps_per_block).
        # Unlike a schedule cached on the block's first call (keyed by the
        # tensor's address), it stays correct when a fresh tensor enters
        # mid-block. It also requests every remaining mask on the block's last
        # step, and it never asks for more than the block's remaining masks.
        block_mask_index = x[:, block_start:block_end] == mask_id
        remaining_steps = steps_per_block - step_in_block
        num_transfer_tokens = get_num_transfer_tokens(
            block_mask_index, remaining_steps
        )[:, 0]

        mask_index = x == mask_id
        if cfg.cfg_scale > 0.0:
            raise RuntimeError("cfg_scale > 0 is not supported")

        logits = self.model(x, attention_mask=attention_mask).logits
        if cfg.eos_penalty > 0:
            logits[:, :, cfg.eos_token_id] -= cfg.eos_penalty
        if not _ALLOW_MASK_TOKEN and 0 <= mask_id < logits.shape[-1]:
            # Disallow predicting the mask token itself: transferring it would
            # not reduce the mask count, leaving masks behind.
            logits[:, :, mask_id] = torch.finfo(logits.dtype).min

        # Pick a token for every position (argmax at temperature 0, else a
        # Gumbel-max sample).
        logits_with_noise = add_gumbel_noise(logits, temperature=cfg.temperature)
        x0 = torch.argmax(logits_with_noise, dim=-1)

        # Applied after x0 is chosen: this lowers only the EOT confidence.
        if cfg.eot_confidence_penalty > 0:
            for eot_token_id in cfg.eot_token_ids:
                logits[:, :, eot_token_id] -= cfg.eot_confidence_penalty

        # Score each position; the highest-scoring masked ones get unmasked.
        if cfg.alg == "low_confidence":
            p = F.softmax(logits, dim=-1)
            x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)
        elif cfg.alg == "random":
            x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
        else:
            raise NotImplementedError(cfg.alg)

        # Restrict transfer candidates to the active generation block (empty
        # slices are no-ops when the block touches either end).
        x0_p[:, :block_start] = -np.inf
        x0_p[:, block_end:] = -np.inf

        x0 = torch.where(mask_index, x0, x)
        confidence = torch.where(mask_index, x0_p, -np.inf)

        transfer_index = torch.zeros_like(x0, dtype=torch.bool)
        for row, k in enumerate(num_transfer_tokens.tolist()):
            if k == 0:
                continue
            _, select_index = torch.topk(confidence[row], k=k)
            transfer_index[row, select_index] = True

        remasker = LladaRemasker(x=x, transfer_index=transfer_index)
        if _DEBUG_DEMASK:
            remasker._debug_transfer_count = int(transfer_index.sum().item())
        return x0, remasker


class LladaRemasker(Remasker):
    """Apply the transfer decision made by one ``Llada.predict_x0`` call.

    ``step`` overwrites ``x`` in place at the positions flagged in
    ``transfer_index`` and returns that same tensor.
    """

    def __init__(self, x: Tensor, transfer_index: Tensor):
        # Held by reference: `step` mutates this tensor directly.
        self.x = x
        # Boolean selector of positions in `x` to overwrite (Llada.predict_x0
        # builds it with the same shape as `x`).
        self.transfer_index = transfer_index

    @torch.no_grad()
    def step(self, x_0: Tensor) -> Tensor:
        """Copy ``x_0`` into ``self.x`` at the selected positions; return ``self.x``."""
        chosen = self.transfer_index
        target = self.x
        target[chosen] = x_0[chosen]
        return target
