import logging
from typing import Iterator, Optional

import torch
from transformers import (
    BatchEncoding,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
)


logger = logging.getLogger(__name__)


def tokenize(
    tokenizer: PreTrainedTokenizerBase,
    text: list[str],
    max_length: Optional[int] = None,
) -> BatchEncoding:
    """Encode a batch of strings as PyTorch tensors, padding on the left.

    If ``tokenizer.padding_side`` is ``"right"``, a warning is logged and the
    attribute is switched to ``"left"`` on the tokenizer object itself, so the
    change remains visible to the caller after this function returns.

    Args:
        tokenizer: Tokenizer to call; may be mutated as described above.
        text: The strings to encode as a single batch.
        max_length: When given, every sequence is padded to exactly this
            length (``padding="max_length"``). When ``None``, sequences are
            padded to the longest one in the batch (``padding="longest"``).

    Returns:
        Whatever the tokenizer call returns when asked for ``"pt"`` tensors
        with token type ids disabled.
    """
    if tokenizer.padding_side == "right":
        logger.warning("Padding side is right, setting it to left")
        tokenizer.padding_side = "left"

    # A fixed target length pads up to that length; otherwise pad only as far
    # as the longest sequence in this batch.
    padding_strategy = "longest" if max_length is None else "max_length"

    # NOTE(review): truncation stays enabled even when max_length is None;
    # the effective limit is then decided by the tokenizer itself -- confirm
    # against the transformers documentation.
    encode_options = {
        "return_tensors": "pt",
        "return_token_type_ids": False,
        "truncation": True,
        "padding": padding_strategy,
        "max_length": max_length,
    }
    return tokenizer(text, **encode_options)


def iter_masked(
    items: torch.Tensor,
    mask: torch.Tensor,
) -> Iterator[torch.Tensor]:
    """Yield, row by row, the entries of ``items`` selected by ``mask``.

    ``mask`` is converted to boolean, then ``items`` and the boolean mask are
    walked in lockstep along their first dimension. Each row of ``items`` is
    indexed with its matching mask row, so yielded tensors may differ in
    length from one row to the next.

    Args:
        items: Tensor whose rows are filtered.
            Presumably shaped (batch, seq), e.g. token ids -- confirm.
        mask: Tensor convertible to bool; truthy positions are kept.
            Presumably an attention mask shaped like ``items`` -- confirm.

    Yields:
        One tensor per row holding only the entries where the mask is true.
        Iteration stops at the shorter of ``items`` and ``mask``.
    """
    keep = mask.bool()
    for row, row_keep in zip(items, keep):
        # Boolean indexing drops every position where the mask is False.
        yield row[row_keep]
