import dataclasses
from logging import getLogger
from typing import Literal, TypeAlias

from transformers import PreTrainedTokenizer

logger = getLogger(__name__)

MessagesType: TypeAlias = list[dict[Literal["user", "system"], str]]

# Special-token ids for each supported model, keyed by the Hugging Face model id.
# Lookups in this file use ``tokenizer.name_or_path`` as the key, so a tokenizer
# whose name_or_path differs from these ids (e.g. presumably one loaded from a
# local directory) is rejected until an entry is added here.
#   mask: id of the mask token that fills not-yet-generated positions
#   pad:  id of the padding token
#   eos:  id of the end-of-sequence token
#   eot:  id of the end-of-turn token
# NOTE: several entries use the same id for pad and eos. When ids are inverted
# (TokenSegment.from_token_ids), the later key ("eos") wins for such ids.
TOKEN_ID_MAPPER: dict[str, dict[Literal["mask", "pad", "eos", "eot"], int]] = {
    "GSAI-ML/LLaDA-8B-Instruct": {
        "mask": 126336,
        "pad": 126081,
        "eos": 126081,
        "eot": 126348,
    },
    "Dream-org/Dream-Coder-v0-Instruct-7B": {
        "mask": 151666,
        "pad": 151643,
        "eos": 151643,
        "eot": 151645,
    },
    "Dream-org/Dream-v0-Instruct-7B": {
        "mask": 151666,
        "pad": 151643,
        "eos": 151643,
        "eot": 151645,
    },
    # DiffuCoder uses the Dream tokenizer but adds its own trained pad token
    # <|dlm_pad|>, so pad and eos have distinct ids here.
    "apple/DiffuCoder-7B-cpGRPO": {
        "mask": 151666,
        "pad": 151667,
        "eos": 151643,
        "eot": 151645,
    },
}


@dataclasses.dataclass
class TokenSegment:
    """
    One run of tokens in the tokenizer-agnostic intermediate representation (IR).

    A segment is either decoded ``text`` or ``repetition`` copies of a special
    token (``mask``, ``pad``, ``eos`` or ``eot``) whose id is looked up in
    ``TOKEN_ID_MAPPER`` by ``tokenizer.name_or_path``.

    We usually only need text, mask, pad, eot.
    NOTE: eot indicates the end of turn, while eos stands for the end of sentences.
    As of now, for non-diffucoder models, eos and pad tokens are the same.
    """

    kind: Literal["text", "mask", "pad", "eos", "eot"]
    # Decoded text; only used when kind == "text".
    content: str = ""
    # Number of copies of the special token; ignored when kind == "text".
    repetition: int = 1

    def num_tokens(self, tokenizer: PreTrainedTokenizer) -> int:
        """Return the number of ids ``to_token_ids`` produces for this segment.

        Text is measured with ``tokenizer.encode`` (the same call used by
        ``to_token_ids``); special-token segments contribute ``repetition`` ids.
        """
        if self.kind == "text":
            return len(tokenizer.encode(self.content))
        # A non-positive repetition yields no ids in to_token_ids.
        return max(self.repetition, 0)

    def to_token_ids(self, tokenizer: PreTrainedTokenizer) -> list[int]:
        """Convert this segment into token ids for ``tokenizer``.

        Raises:
            RuntimeError: if ``tokenizer.name_or_path`` is not in TOKEN_ID_MAPPER.
        """
        model_name = tokenizer.name_or_path
        if model_name not in TOKEN_ID_MAPPER:
            raise RuntimeError(
                f"token ids of model_name {model_name} are not mapped yet. Please add model to the TOKEN_ID_MAPPER."
            )

        if self.kind == "text":
            return tokenizer.encode(self.content)
        return [TOKEN_ID_MAPPER[model_name][self.kind]] * self.repetition

    @classmethod
    def from_token_ids(
        cls, tokenizer: PreTrainedTokenizer, token_ids: list[int]
    ) -> tuple[list["TokenSegment"], int]:
        """Parse token ids back into segments.

        Consecutive ordinary ids are decoded together into one ``text`` segment.
        Every special id becomes its own segment with ``repetition=1``.

        Returns:
            ``(segments, num_masks)``, where ``num_masks`` counts the mask ids.

        Raises:
            RuntimeError: if ``tokenizer.name_or_path`` is not in TOKEN_ID_MAPPER.
        """
        model_name = tokenizer.name_or_path
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if model_name not in TOKEN_ID_MAPPER:
            raise RuntimeError(f"{model_name} is not mapped by TOKEN_ID_MAPPER")
        # Invert kind -> id. When two kinds share an id (e.g. pad == eos), the
        # later key in the mapping wins, so such ids come back as "eos".
        id_to_kind = {tid: kind for kind, tid in TOKEN_ID_MAPPER[model_name].items()}

        segments: list["TokenSegment"] = []
        num_masks = 0
        text_ids: list[int] = []

        def flush_text() -> None:
            # Emit the pending ordinary ids (if any) as a single text segment.
            if text_ids:
                segments.append(cls(kind="text", content=tokenizer.decode(text_ids)))
                text_ids.clear()

        for token_id in token_ids:
            kind = id_to_kind.get(token_id)
            if kind is None:
                text_ids.append(token_id)
                continue
            flush_text()
            segments.append(cls(kind=kind))
            if kind == "mask":
                num_masks += 1
        flush_text()

        return segments, num_masks


@dataclasses.dataclass
class TokenSequence:
    """
    This class is responsible for bridging the token representations among different tokenizers. Also it does token sequence truncation if necessary.

    The sequence structure is for example like:
    <mask><generated_content><mask><generated_content><eot><pad><pad><pad>...

    We store prompt_msgs to account for the different chat template styles (e.g., LLaDA follows Llama3 and Dream follows ChatML).
    The mask and generation segments are stored as IR (TokenSegment), a sequence of text, mask, eot and pad segments.
    """

    prompt_msgs: MessagesType
    segments: list[TokenSegment]
    # Denominator of mask_fraction. For an overwrite input, init_diffusion_input
    # excludes the overwrite text's tokens from this value.
    gen_length: int
    # Number of mask tokens in `segments`.
    num_masks: int
    # Optional cached shape information for visualization/export.
    # Indices are in the generated token segment (i.e., not including the prompt tokens).
    generated_token_count: int | None = None
    mask_positions: list[int] | None = None

    # Tokenized prompt ids per model, keyed by tokenizer.name_or_path.
    prompt_tokens_cache: dict[str, list[int]] = dataclasses.field(
        default_factory=dict
    )

    @classmethod
    def init_diffusion_input(
        cls,
        prompt_msgs: MessagesType,
        gen_length: int,
        tokenizer: PreTrainedTokenizer | None = None,
        overwrite_text: str | None = None,
        overwrite_position: int | None = None,
    ) -> "TokenSequence":
        """
        Initialize diffusion model input, optionally with overwrite at some position of mask.

        Without an overwrite, the generated region is ``gen_length`` masks.
        With an overwrite, ``overwrite_text`` starts ``abs(overwrite_position)``
        tokens before the end of the region; all other positions are masks, and the
        returned ``gen_length``/``num_masks`` exclude the overwrite text's tokens.

        Args:
            prompt_msgs: chat messages rendered through the tokenizer's chat template.
            gen_length: number of token positions in the generated region.
            tokenizer: required for an overwrite, to count the text's tokens.
            overwrite_text: text to place inside the generated region.
            overwrite_position: negative offset from the end of the generated region.

        Raises:
            RuntimeError: on a non-negative position, missing text or tokenizer,
                text that does not fit before the end of the region, or a position
                before the start of the region.
        """
        if overwrite_position is None:
            return cls(
                prompt_msgs,
                segments=[TokenSegment(kind="mask", repetition=gen_length)],
                gen_length=gen_length,
                num_masks=gen_length,
                generated_token_count=int(gen_length),
                mask_positions=list(range(int(gen_length))),
            )

        if overwrite_position >= 0:
            raise RuntimeError(
                f"specify overwrite position by negative value; {overwrite_position} was given"
            )
        if overwrite_text is None:
            raise RuntimeError("specify overwrite text")
        if tokenizer is None:
            raise RuntimeError("specify tokenizer to count the tokens")
        offset = abs(overwrite_position)
        token_len = len(tokenizer.encode(overwrite_text))
        if token_len > offset:
            raise RuntimeError(
                f"token length of overwrite text {overwrite_text} is {token_len}, and cannot be put at the specified position {overwrite_position}"
            )
        # Without this check the leading mask run gets a negative repetition,
        # which silently yields a sequence of the wrong length.
        if offset > gen_length:
            raise RuntimeError(
                f"overwrite position {overwrite_position} lies before the start of the generated region of length {gen_length}"
            )

        head_masks = gen_length - offset
        tail_masks = offset - token_len
        segments: list[TokenSegment] = [
            TokenSegment(kind="mask", repetition=head_masks),
            TokenSegment(kind="text", content=overwrite_text),
        ]
        if tail_masks > 0:
            segments.append(TokenSegment(kind="mask", repetition=tail_masks))

        # Every position in the generated region is a mask except the token_len
        # positions occupied by the overwrite text.
        text_end = head_masks + token_len
        mask_positions = list(range(head_masks)) + list(range(text_end, int(gen_length)))

        return cls(
            prompt_msgs,
            segments=segments,
            gen_length=gen_length - token_len,
            num_masks=gen_length - token_len,
            generated_token_count=int(gen_length),
            mask_positions=mask_positions,
        )

    def to_token_ids(
        self,
        tokenizer: PreTrainedTokenizer,
        truncate_tid_strategy: Literal["right_pad_mask_remove"] | None = None,
        max_tid_len: int = -1,
    ) -> tuple[int, list[int]]:
        """Build the model input: chat-templated prompt ids followed by segment ids.

        Args:
            tokenizer: tokenizer whose ``name_or_path`` is a key of TOKEN_ID_MAPPER.
            truncate_tid_strategy: ``"right_pad_mask_remove"`` to shrink the result
                to ``max_tid_len`` ids; ``None`` to return it untruncated.
            max_tid_len: maximum total length; must be positive when truncating.

        Returns:
            ``(prompt_len, token_ids)``, where ``prompt_len`` is the number of prompt
            ids at the front of ``token_ids`` (after any prompt trimming).
        """
        # Copy so that extending below never mutates the cached prompt ids.
        token_ids = list(self._tokenize_prompt_msgs(tokenizer))
        prompt_len = len(token_ids)

        # mask and generated segments
        for segment in self.segments:
            token_ids.extend(segment.to_token_ids(tokenizer))

        if truncate_tid_strategy == "right_pad_mask_remove":
            model_name = tokenizer.name_or_path
            if model_name not in TOKEN_ID_MAPPER:
                raise RuntimeError(
                    f"token ids of model_name {model_name} are not mapped yet. "
                    "Please add model to the TOKEN_ID_MAPPER."
                )
            return self._truncate_right_pad_mask_remove(
                token_ids=token_ids,
                prompt_len=prompt_len,
                max_tid_len=max_tid_len,
                pad_id=TOKEN_ID_MAPPER[model_name]["pad"],
                mask_id=TOKEN_ID_MAPPER[model_name]["mask"],
            )

        return prompt_len, token_ids

    def _tokenize_prompt_msgs(self, tokenizer: PreTrainedTokenizer) -> list[int]:
        """Return the chat-templated prompt ids, cached per tokenizer.name_or_path.

        The returned list is the cached object itself; callers must copy it
        before mutating it.
        """
        model_name: str = tokenizer.name_or_path
        if model_name not in self.prompt_tokens_cache:
            self.prompt_tokens_cache[model_name] = tokenizer.apply_chat_template(
                self.prompt_msgs,
                add_generation_prompt=True,
                tokenize=True,
                return_tensors=None,
            )
        return self.prompt_tokens_cache[model_name]

    def build_from_generated_token_ids(
        self, tokenizer: PreTrainedTokenizer, generated_token_ids: list[int]
    ) -> "TokenSequence":
        """Create a new sequence from the ids of the generated region.

        The new sequence reuses this sequence's prompt, prompt cache (the same
        dict object) and ``gen_length``. Its segments and mask count are parsed
        from ``generated_token_ids``.
        """
        segments, num_masks = TokenSegment.from_token_ids(
            tokenizer, token_ids=generated_token_ids
        )
        generated_token_count = len(generated_token_ids)
        mask_positions: list[int] | None
        try:
            mask_id = TOKEN_ID_MAPPER[tokenizer.name_or_path]["mask"]
        except KeyError:
            # Best-effort only; visualization can fall back to aggregated mask counts.
            logger.debug(
                "no mask id mapped for %s; mask_positions left unset",
                tokenizer.name_or_path,
            )
            mask_positions = None
        else:
            mask_positions = [
                idx
                for idx, token_id in enumerate(generated_token_ids)
                if token_id == mask_id
            ]
        return TokenSequence(
            prompt_msgs=self.prompt_msgs,
            segments=segments,
            prompt_tokens_cache=self.prompt_tokens_cache,
            gen_length=self.gen_length,
            num_masks=num_masks,
            generated_token_count=generated_token_count,
            mask_positions=mask_positions,
        )

    @staticmethod
    def _remove_rightmost_pad_mask_anywhere(
        token_ids: list[int],
        prompt_len: int,
        to_remove: int,
        pad_id: int,
        mask_id: int,
    ) -> list[int]:
        """Drop up to ``to_remove`` pad/mask ids from the generated region.

        Only indices ``>= prompt_len`` are eligible. Removal goes right to left,
        even when other ids sit between the pad/mask ids. If ``to_remove <= 0``,
        the input list itself is returned; otherwise a new list is returned.
        """
        if to_remove <= 0:
            return token_ids

        removable = (pad_id, mask_id)
        removed = 0
        kept_rev: list[int] = []

        # scan from right, remove rightmost pad/mask even if interspersed
        for i in range(len(token_ids) - 1, -1, -1):
            t = token_ids[i]
            if i >= prompt_len and removed < to_remove and t in removable:
                removed += 1
                continue
            kept_rev.append(t)

        kept_rev.reverse()
        return kept_rev

    @classmethod
    def _truncate_right_pad_mask_remove(
        cls,
        token_ids: list[int],
        prompt_len: int,
        max_tid_len: int,
        pad_id: int,
        mask_id: int,
    ) -> tuple[int, list[int]]:
        """Shrink ``token_ids`` to at most ``max_tid_len`` ids.

        The steps are applied in order until the sequence fits:
        1. drop pad/mask ids from the generated region, rightmost first;
        2. drop prompt ids from the left, keeping the rightmost part of the prompt;
        3. otherwise raise, because only non-pad/mask generated ids would remain to cut.

        Returns:
            ``(prompt_len, token_ids)`` after truncation.

        Raises:
            RuntimeError: if ``max_tid_len`` is not positive, or the ids cannot fit.
        """
        if max_tid_len <= 0:
            raise RuntimeError(f"max_tid_len must be positive, got {max_tid_len}")

        overflow = len(token_ids) - max_tid_len
        if overflow <= 0:
            return prompt_len, token_ids

        # 1) remove pad/mask from the right, even across other kinds
        new_ids = cls._remove_rightmost_pad_mask_anywhere(
            token_ids=token_ids,
            prompt_len=prompt_len,
            to_remove=overflow,
            pad_id=pad_id,
            mask_id=mask_id,
        )
        overflow -= len(token_ids) - len(new_ids)
        token_ids = new_ids

        if overflow > 0:
            # 2) still overflow => trim prompt from the left (keep the rightmost prompt part)
            cut_prompt = min(overflow, prompt_len)
            if cut_prompt:
                token_ids = token_ids[cut_prompt:]
                prompt_len -= cut_prompt
                overflow -= cut_prompt

        if overflow > 0:
            # 3) cannot fit without cutting non-(pad/mask) generated tokens
            raise RuntimeError(
                "Input is longer than max_tid_len even after removing suffix pad/mask and trimming prompt. "
                "You need another strategy (e.g., allow cutting generated tokens) or increase max_tid_len."
            )
        return prompt_len, token_ids

    @property
    def mask_fraction(self) -> float:
        """Return ``num_masks / gen_length``.

        Returns 0.0 when ``gen_length`` is 0 instead of raising ZeroDivisionError.
        This happens, for example, when an overwrite text fills every position of
        the generated region.
        """
        if self.gen_length == 0:
            return 0.0
        return self.num_masks / self.gen_length
