from typing import cast

import torch
from transformers import BatchEncoding, PreTrainedModel
from transformers.utils import ModelOutput

from lib_llm.inference import PredictionConfig, predict


def test_predict():
    """Run ``predict`` over several random batches using a stub model.

    The stub model ignores the values of its inputs and always returns
    all-ones logits of shape ``(*input_ids.shape, 2)``. The test checks that
    ``predict`` yields exactly one output per input batch, and that each
    output's ``"logits"`` has shape ``(rows, seq_len - 1, 2)``.
    """

    class DummyModel(torch.nn.Module):
        """Stand-in for a ``PreTrainedModel`` that stays on the CPU."""

        def __init__(self):
            super().__init__()
            # Exposed as a plain attribute, presumably because `predict`
            # reads `model.device` as it would on a real PreTrainedModel
            # (assumption — confirm against lib_llm.inference).
            self.device = torch.device("cpu")

        def forward(self, input_ids, attention_mask):
            # Two logits per token position, regardless of the input values.
            logit_shape = tuple(input_ids.shape) + (2,)
            return ModelOutput(logits=torch.ones(logit_shape))

        def to(self, device):
            # Moving the stub is a no-op; it always remains on the CPU.
            return self

    num_batches, rows_per_batch, seq_len = 4, 32, 20

    def make_batch():
        # Random token ids in [100, 1000) with an all-ones attention mask.
        return BatchEncoding(
            {
                "input_ids": torch.randint(100, 1000, (rows_per_batch, seq_len)),
                "attention_mask": torch.ones((rows_per_batch, seq_len)),
            }
        )

    # Seed before drawing the batches so the random token ids are reproducible.
    torch.manual_seed(853)
    batches = [make_batch() for _ in range(num_batches)]
    model = cast(PreTrainedModel, DummyModel())

    # batch_size differs from rows_per_batch; presumably this exercises
    # re-batching inside `predict` (assumption — confirm).
    config = PredictionConfig(batch_size=10)
    outputs = list(predict(model, batches, config))

    assert len(outputs) == num_batches
    for batch, batch_output in zip(batches, outputs):
        logits = batch_output["logits"]
        # One position fewer than the input sequence; presumably `predict`
        # shifts logits for next-token prediction (assumption — confirm).
        expected_shape = (batch.input_ids.shape[0], seq_len - 1, 2)
        assert logits.shape == expected_shape
