import abc

from torch import Tensor


class Remasker(abc.ABC):
    """Strategy object that re-masks part of a predicted clean sample.

    A ``Remasker`` is returned alongside ``x_0`` by
    ``DemaskStepper.predict_x0`` and is then applied to that ``x_0`` by
    ``DemaskStepper.step``. Subclasses decide which positions to re-mask
    (for example, the positions predicted with low confidence).
    """

    @abc.abstractmethod
    def step(self, x_0: Tensor) -> Tensor:
        """Return ``x_0`` with some (usually a fixed fraction of) positions re-masked.

        Must be overridden by subclasses; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError


class DemaskStepper(abc.ABC):
    """Drives one denoising step of a diffusion language model.

    Subclasses implement ``predict_x0``; ``step`` composes that prediction
    with the ``Remasker`` it returns.
    """

    def step(
        self, x: Tensor, step_idx: int, attention_mask: Tensor | None = None
    ) -> Tensor:
        """Run one denoising step of a diffusion language model.

        The step has two phases:

        1. ``predict_x0`` fills in **every** [MASK] position of ``x`` to
           produce a clean-sample estimate ``x_0``, and also returns a
           ``Remasker``.
        2. That ``Remasker`` re-masks some (usually fixed) fraction of
           ``x_0``; for example, low-confidence token positions.

        Args:
            x: Current sequence, which may still contain [MASK] tokens.
            step_idx: Index of this step; increases as 0, 1, 2, ..., T-1 as
                demasking progresses.
            attention_mask: Used for batch processing. If ``None``, all
                tokens are attended. A valid mask for a batch can be obtained
                from the tokenizer:

                ```python
                inputs = tokenizer.apply_chat_template(
                    messages, return_tensors="pt", return_dict=True, add_generation_prompt=True
                )
                input_ids = inputs.input_ids.to(device="cuda")
                attention_mask = inputs.attention_mask.to(device="cuda")
                ```

        Returns:
            Whatever the returned ``Remasker.step`` produces for ``x_0``.
        """
        prediction = self.predict_x0(
            x, step_idx=step_idx, attention_mask=attention_mask
        )
        clean_estimate, remasker = prediction
        return remasker.step(x_0=clean_estimate)

    @abc.abstractmethod
    def predict_x0(
        self, x: Tensor, step_idx: int, attention_mask: Tensor | None = None
    ) -> tuple[Tensor, Remasker]:
        """Demask every token of ``x``.

        Takes the same arguments as ``step``.

        Returns:
            A pair ``(x_0, remasker)``. ``x_0`` is the prediction with all
            [MASK] tokens filled in. ``remasker`` is later applied to ``x_0``
            by ``step`` to re-mask some of its positions.
        """
        raise NotImplementedError
