import torch
from core.model import Preprocessor
from core.utils import th


type StopTokens = list[tuple[int, ...] | int | str] | tuple[int, ...] | int | str
"""
The subsequence (subsequences) that stops the inference process once detected.
    *tuple*: token indices of the subsequence.
    *int*: the index of the token in a subsequence that has only one token.
    *str*: the textual tokens in the subsequence.
    *list*: enumerate multiple items as above.
"""


def default_stop_tokens(preprocessor: Preprocessor) -> StopTokens:
    """Build the stop tokens used when no explicit ones are supplied.

    Returns a one-element list holding the tokenizer's ``eos_id`` when that
    attribute is set, or an empty list when it is ``None``.
    """
    eos_id = preprocessor.tokenizer.eos_id
    return [eos_id] if eos_id is not None else []


def to_stop_seqs(preprocessor: Preprocessor, tokens: StopTokens | None) -> list[torch.Tensor]:
    """Normalize stop-token specifications into a list of int32 tensors.

    Args:
        preprocessor: Supplies the target ``device``, the ``encode`` method used
            for textual stop sequences, and (via ``default_stop_tokens``) the
            fallback stop tokens.
        tokens: The stop subsequence(s), in any form accepted by ``StopTokens``.
            ``None`` selects ``default_stop_tokens(preprocessor)``.

    Returns:
        One ``torch.int32`` tensor per stop subsequence, all placed on
        ``preprocessor.device``. Tuples and ints yield 1-D tensors; strings
        yield whatever ``preprocessor.encode`` returns, cast and moved.
    """
    device = preprocessor.device
    dtype = torch.int32

    if tokens is None:
        tokens = default_stop_tokens(preprocessor)
    # A bare tuple/int/str describes a single subsequence; wrap it so the
    # loop below handles every input uniformly.
    if not isinstance(tokens, list):
        tokens = [tokens]

    seqs: list[torch.Tensor] = []
    for seq in tokens:
        if isinstance(seq, str):
            # Move as well as cast: the int/tuple branch builds directly on
            # ``device``, so every returned tensor ends up on the same device.
            seqs.append(preprocessor.encode(seq).to(device=device, dtype=dtype))
        else:
            ids = [seq] if isinstance(seq, int) else seq
            seqs.append(torch.tensor(ids, dtype=dtype, device=device))
    return seqs
