"""
Reimplementation of
https://huggingface.co/Dream-org/Dream-Coder-v0-Instruct-7B/blob/main/generation_utils.py
"""

import os
from dataclasses import dataclass
from typing import Literal

import torch
import torch.distributions as dists
from torch import Tensor
from torch.nn import functional as F
from transformers import PreTrainedModel

from bd_mcts.demask_interface.base import DemaskStepper, Remasker

# Names of the supported unmasking strategies.
AlgType = Literal["origin", "entropy", "maskgit_plus", "topk_margin"]


def _env_flag(name: str) -> bool:
    """Return True when environment variable *name* holds a truthy spelling.

    The value is stripped and lower-cased before being compared; an unset
    variable counts as false.
    """
    raw = os.environ.get(name, "")
    return raw.strip().lower() in ("1", "true", "yes", "y", "on")


# When set, the mask token's logit is left untouched at masked positions.
_ALLOW_MASK_TOKEN = _env_flag("BD_MCTS_DEMASK_ALLOW_MASK_TOKEN")
# When set, the remasker records extra `_debug_*` counters on itself.
_DEBUG_DEMASK = _env_flag("BD_MCTS_DEBUG_DEMASK")


@dataclass
class DreamGenerationConfig:
    """Settings for Dream diffusion decoding.

    The four special-token ids default to ``-1`` as a "not provided" sentinel
    and must be overridden with non-negative values.

    Raises:
        AssertionError: if any of ``mask_token_id``, ``pad_token_id``,
            ``bos_token_id`` or ``eos_token_id`` is negative.
    """

    # number of diffusion steps
    steps: int = 512
    # unmasking strategy; "origin" transfers positions at random, the other
    # strategies rank masked positions by a confidence score
    alg: AlgType = "origin"
    # temperature used when sampling which positions to unmask; None or 0
    # selects the highest-confidence positions deterministically
    alg_temp: float | None = None
    # scale of the time-dependent term log(1 - t + eps) added to the pad-token logit
    eos_penalty: float = 0.0

    # sampling
    temperature: float = 0.0
    top_p: float | None = None
    top_k: int | None = None
    # length the tokenizer attention mask is right-padded to
    max_length: int = 20
    max_new_tokens: int | None = None
    eps: float = 1e-3  # timestep of all clean tokens, (i.e., t=1~eps)

    mask_token_id: int = -1
    pad_token_id: int = -1
    bos_token_id: int = -1
    eos_token_id: int = -1

    def __post_init__(self):
        """Reject configurations whose special-token ids were left unset."""
        # Raised explicitly rather than via `assert` so the check still runs
        # under `python -O`; AssertionError is kept so existing callers see
        # the same exception type as before.
        for name in ("mask_token_id", "pad_token_id", "bos_token_id", "eos_token_id"):
            value = getattr(self, name)
            if not value >= 0:
                raise AssertionError(
                    f"{name} must be a non-negative token id, got {value!r}"
                )


def top_p_logits(logits: Tensor, top_p: float):
    """Apply nucleus (top-p) filtering along the last dimension.

    Tokens are ranked by logit; the smallest prefix whose cumulative softmax
    probability exceeds ``top_p`` is kept (the top-ranked token always
    survives) and every other logit is replaced by the dtype's minimum finite
    value. A new tensor is returned; ``logits`` is not modified.
    """
    ranked_logits, ranked_ids = logits.sort(dim=-1, descending=True)
    running_mass = ranked_logits.softmax(dim=-1).cumsum(dim=-1)
    over_budget = running_mass > top_p

    # Shift one slot to the right so the token that crosses the threshold is
    # still kept, and never drop the top-ranked token.
    always_keep = torch.zeros_like(over_budget[..., :1])
    drop_ranked = torch.cat([always_keep, over_budget[..., :-1]], dim=-1)

    # Map the decision from rank order back to vocabulary order.
    drop = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    drop.scatter_(-1, ranked_ids, drop_ranked)
    return logits.masked_fill(drop, torch.finfo(logits.dtype).min)


def top_k_logits(logits: Tensor, top_k: int):
    """Keep the ``top_k`` largest logits along the last dimension.

    Every logit strictly below the k-th largest value is replaced by the
    dtype's minimum finite value, so entries tied with the k-th value are
    kept. ``top_k`` is clamped to the size of the last dimension.
    """
    k = min(top_k, logits.size(-1))  # never ask for more entries than exist
    # Smallest logit that is still inside the top-k, kept as a trailing axis
    # of size 1 so it broadcasts against `logits`.
    kth_largest = logits.topk(k).values[..., -1:]
    below_cutoff = logits < kth_largest
    return logits.masked_fill(below_cutoff, torch.finfo(logits.dtype).min)


@torch.no_grad
def sample_tokens(
    logits: Tensor,
    temperature: float = 0.0,
    top_p: float | None = None,
    top_k: int | None = None,
    margin_confidence: bool = False,
    neg_entropy: bool = False,
):
    original_dtype = logits.dtype
    logits = logits.to(torch.float32)
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        x0 = dists.Categorical(probs=probs).sample()
        confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Extract top1 and top2 probabilities
        top1_probs = sorted_probs[:, 0]
        top2_probs = sorted_probs[:, 1]
        # Calculate confidence as top1 - top2
        confidence = top1_probs - top2_probs

    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence.to(original_dtype), x0


class DreamRemasker(Remasker):
    """Commits predicted tokens into the sequence for one diffusion step.

    An instance captures the state of a single step: the sequence ``x``, the
    boolean ``mask_index`` of positions that currently hold the mask token,
    the per-masked-position ``confidence`` (``None`` for the "origin"
    algorithm) and the step's timesteps ``t`` (current) and ``s`` (next).
    """

    def __init__(
        self,
        confidence: Tensor | None,
        mask_index: Tensor,
        x: Tensor,
        model: PreTrainedModel,
        generation_config: "DreamGenerationConfig",
        t: Tensor,
        s: Tensor,
        step_idx: int,
    ):
        """Store the step state; ``x`` is kept by reference and later mutated."""
        self.confidence = confidence
        self.mask_index = mask_index
        self.x = x
        self.model = model
        self.generation_config = generation_config

        self.t = t
        self.s = s
        self.step_idx = step_idx

    @torch.no_grad()
    def step(self, x_0: Tensor) -> Tensor:
        """Write tokens from ``x_0`` into ``self.x`` and return ``self.x``.

        Args:
            x_0: one candidate token per ``True`` entry of ``mask_index``.

        Returns:
            ``self.x``, modified in place. For "origin" every masked position
            receives its ``x_0`` entry. For the other algorithms only the
            selected masked positions are filled. On every step but the last
            the selection count is ``int(avg_masked * (1 - s / t))``; on the
            last step it is ``int(avg_masked)``, where ``avg_masked`` is the
            number of masked positions divided by the batch size.
        """
        if self.generation_config.alg == "origin":
            if _DEBUG_DEMASK:
                self._debug_transfer_count = int(self.mask_index.sum().item())
            self.x[self.mask_index] = x_0.clone()
            return self.x

        assert self.confidence is not None, (
            "Internal Error! For non-origin alg, confidence should be not None"
        )

        # Average number of masked positions per batch row.
        num_mask_token = self.mask_index.sum() / self.mask_index.shape[0]
        # Non-masked positions get -inf so they rank last when selecting.
        full_confidence = torch.full_like(
            self.x, -torch.inf, device=self.model.device, dtype=self.confidence.dtype
        )
        number_transfer_tokens = (
            int(num_mask_token * (1 - self.s / self.t))
            if self.step_idx < self.generation_config.steps - 1
            else int(num_mask_token)
        )
        if _DEBUG_DEMASK:
            self._debug_transfer_count = int(number_transfer_tokens)
        full_confidence[self.mask_index] = self.confidence
        if number_transfer_tokens > 0:
            if (
                self.generation_config.alg_temp is None
                or self.generation_config.alg_temp == 0
            ):
                # Deterministic: take the most confident positions of each row.
                _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
            else:
                # Stochastic: sample positions in proportion to the
                # temperature-scaled softmax of the confidence.
                full_confidence = full_confidence / self.generation_config.alg_temp
                full_confidence = F.softmax(full_confidence, dim=-1)
                transfer_index = torch.multinomial(
                    full_confidence, num_samples=number_transfer_tokens
                )
            # Full-size candidate sequence: predictions at masked positions,
            # the mask token everywhere else.
            x_ = (
                torch.zeros_like(self.x, device=self.model.device, dtype=torch.long)
                + self.generation_config.mask_token_id
            )
            x_[self.mask_index] = x_0.clone()
            if _DEBUG_DEMASK:
                self._debug_selected_mask_count = int(
                    (
                        x_[
                            torch.arange(self.x.size(0), device=self.model.device).unsqueeze(1),
                            transfer_index,
                        ]
                        == self.generation_config.mask_token_id
                    )
                    .sum()
                    .item()
                )
            row_indices = (
                torch.arange(self.x.size(0), device=self.model.device)
                .unsqueeze(1)
                .expand_as(transfer_index)
            )
            # The same number of positions is selected for every row, so a row
            # with fewer masked positions than that count also gets non-masked
            # (-inf confidence) positions selected. Writing x_ there would
            # replace an already-clean token with the mask token, so only
            # positions that are currently masked are committed.
            selected = torch.zeros_like(self.mask_index)
            selected[row_indices, transfer_index] = True
            selected &= self.mask_index
            self.x[selected] = x_[selected]

        return self.x


class Dream(DemaskStepper):
    """Single-step Dream diffusion demasker built around a language model.

    Keyword arguments given to the constructor are forwarded to
    ``DreamGenerationConfig``. The timestep schedule is ``steps + 1`` evenly
    spaced values from 1 down to ``eps``; step ``i`` goes from
    ``timesteps[i]`` to ``timesteps[i + 1]``.
    """

    def __init__(self, model: PreTrainedModel, **kwargs) -> None:
        self.model = model
        self.generation_config = DreamGenerationConfig(**kwargs)

        self.timesteps = torch.linspace(
            1,
            self.generation_config.eps,
            self.generation_config.steps + 1,
            device=self.model.device,
        )

    @torch.no_grad()
    def process_attn_mask_from_tokenizer(
        self, attention_mask: torch.Tensor | None, max_length: int
    ) -> tuple[torch.Tensor | Literal["full"], Tensor | None]:
        """Turn a tokenizer attention mask into model attention/position inputs.

        Args:
            attention_mask: mask of shape ``[B, N]`` or ``None``.
            max_length: length the mask is right-padded to (with ones).

        Returns:
            ``("full", None)`` when no mask is given or it contains no zeros.
            Otherwise ``(mask, tok_idx)``: ``mask`` is a boolean
            ``[B, 1, L, L]`` tensor that is true where both positions are
            attended, and ``tok_idx`` is ``[B, L]`` holding each position's
            index among the attended positions (zeroed positions get 1).
        """
        if attention_mask is None or not torch.any(attention_mask == 0.0):
            return "full", None

        # we do not mask the [MASK] tokens so value = 1.0
        # NOTE(review): the pad target is the caller-supplied `max_length`,
        # not the length of the sequence fed to the model - presumably they
        # always agree; confirm against callers.
        attention_mask = F.pad(
            attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0
        )
        tok_idx = attention_mask.long().cumsum(-1) - 1
        tok_idx.masked_fill_(attention_mask == 0, 1)
        # attention_mask is of shape [B, N]
        # broadcast to [B, 1, N, N]
        attention_mask = torch.logical_and(
            attention_mask.unsqueeze(1).unsqueeze(-2),
            attention_mask.unsqueeze(1).unsqueeze(-1),
        )
        return attention_mask, tok_idx

    @torch.no_grad()
    def predict_x0(
        self, x: Tensor, step_idx: int, attention_mask: Tensor | None = None
    ) -> tuple[Tensor, DreamRemasker]:
        """Run the model once and propose tokens for the masked positions.

        Args:
            x: token ids; positions equal to ``mask_token_id`` are predicted.
            step_idx: index of the current diffusion step, used to look up
                the timesteps ``t`` and ``s``.
            attention_mask: optional tokenizer attention mask.

        Returns:
            ``(x0, remasker)``. ``x0`` has one entry per masked position of
            ``x``; with the "origin" algorithm, entries not chosen for
            transfer at this step hold ``mask_token_id``. ``remasker`` is a
            ``DreamRemasker`` bound to ``x`` that commits the proposal.

        Raises:
            RuntimeError: if the configured algorithm is not recognised.
        """
        cfg = self.generation_config
        attn_mask, tok_idx = self.process_attn_mask_from_tokenizer(
            attention_mask, max_length=cfg.max_length
        )

        mask_token_id = cfg.mask_token_id
        alg = cfg.alg
        sampling_kwargs = {
            "temperature": cfg.temperature,
            "top_p": cfg.top_p,
            "top_k": cfg.top_k,
        }

        mask_index = x == mask_token_id
        logits = self.model(x, attn_mask, tok_idx).logits
        # Shift right by one position (the first position is duplicated) so
        # that entry i holds the logits produced at position i - 1.
        logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

        # Boolean indexing yields a copy, so the in-place edits below do not
        # touch `logits`.
        mask_logits = logits[mask_index]
        if not _ALLOW_MASK_TOKEN and 0 <= mask_token_id < mask_logits.shape[-1]:
            # The diffusion models are trained to predict clean tokens at masked positions.
            # Allowing the model to predict the mask token itself can stall demasking (mask count
            # never decreases), which breaks downstream algorithms that assume monotonic demasking.
            mask_logits[:, mask_token_id] = torch.finfo(mask_logits.dtype).min
        t = self.timesteps[step_idx]
        s = self.timesteps[step_idx + 1]

        # Time-dependent adjustment of the pad-token logit.
        mask_logits[:, cfg.pad_token_id] += cfg.eos_penalty * torch.log(1 - t + cfg.eps)

        if alg == "origin":
            # Each masked position is transferred independently with
            # probability 1 - s/t (everything on the final step).
            p_transfer = 1 - s / t if step_idx < cfg.steps - 1 else 1
            x0 = (
                torch.zeros_like(
                    x[mask_index], device=self.model.device, dtype=torch.long
                )
                + mask_token_id
            )
            transfer_index_t_s = (
                torch.rand(*x0.shape, device=self.model.device) < p_transfer
            )
            _, x0[transfer_index_t_s] = sample_tokens(
                mask_logits[transfer_index_t_s], **sampling_kwargs
            )
            confidence = None
        else:
            # The confidence-based algorithms differ only in which score
            # `sample_tokens` reports.
            if alg == "maskgit_plus":
                confidence_kwargs = {}
            elif alg == "topk_margin":
                confidence_kwargs = {"margin_confidence": True}
            elif alg == "entropy":
                confidence_kwargs = {"neg_entropy": True}
            else:
                raise RuntimeError(f"Unknown alg: {alg}")
            confidence, x0 = sample_tokens(
                mask_logits, **sampling_kwargs, **confidence_kwargs
            )

        remasker = DreamRemasker(
            confidence=confidence,
            mask_index=mask_index,
            x=x,
            model=self.model,
            generation_config=cfg,
            t=t,
            s=s,
            step_idx=step_idx,
        )
        return x0, remasker
