import torch

from torch import Tensor
from typing import Callable
from transformers.cache_utils import Cache

from .lm import LM


class SingleTokenLM(torch.nn.Module):
    """Wrapper around an ``LM`` that exposes the same interface as a MultiTokenLM.

    The wrapper adds no computation of its own: ``forward`` and ``generate``
    pass their arguments straight through to the wrapped ``lm`` and return its
    result unchanged.

    Args:
        lm: The underlying language model that does the actual work.
    """

    def __init__(self, lm: LM):
        super().__init__()
        self.lm = lm

    def forward(
        self,
        input_ids: Tensor,  # (B, S) input ids
        labels: Tensor,  # (B, S) target ids
        attention_mask: Tensor | None = None,  # (B, S) attention mask, or None
        return_logits: bool = False,
    ) -> dict:
        """Run the wrapped LM on a batch.

        Args:
            input_ids: Input token ids, shape (B, S).
            labels: Target token ids, shape (B, S).
            attention_mask: Attention mask, shape (B, S), or ``None``
                (the default) to let the wrapped LM decide.
            return_logits: Forwarded to the wrapped LM.

        Returns:
            The dict returned by the wrapped LM's forward call.
        """
        return self.lm(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
            return_logits=return_logits,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Tensor,
        use_argmax: bool = False,
        mode: str = "stp",
        use_cache: bool = False,
        past_key_values: Cache | None = None,
        draft_top_p: float = 1.0,
        logit_processor: Callable | None = None,
    ) -> Tensor:
        """Generate tokens with the wrapped LM, with gradients disabled.

        Every argument is forwarded unchanged to ``self.lm.generate``; see
        that method for the meaning of ``use_argmax``, ``mode``,
        ``draft_top_p`` and ``logit_processor``.

        Args:
            inputs: Prompt tensor passed as the first positional argument.
            use_argmax: Forwarded to ``LM.generate``.
            mode: Forwarded to ``LM.generate`` (default ``"stp"``).
            use_cache: Forwarded to ``LM.generate``.
            past_key_values: Optional cache forwarded to ``LM.generate``.
            draft_top_p: Forwarded to ``LM.generate``.
            logit_processor: Optional callable forwarded to ``LM.generate``.

        Returns:
            The tensor returned by ``self.lm.generate``.
        """
        return self.lm.generate(
            inputs,
            use_argmax=use_argmax,
            mode=mode,
            use_cache=use_cache,
            past_key_values=past_key_values,
            draft_top_p=draft_top_p,
            logit_processor=logit_processor,
        )
