from typing import Any

import torch
from transformers import BatchEncoding, TokenizersBackend


class SimpleEncoder:
    """Format chat-style batches into prompt text and tokenize them.

    The wrapped ``tokenizer`` must be callable and expose
    ``apply_chat_template``. When ``default_system_prompt`` is set, a system
    message carrying it is prepended to every conversation before formatting.

    Validation failures are raised as ``AssertionError`` via explicit ``raise``
    statements rather than ``assert``, so the checks still run under
    ``python -O`` while the exception type stays the same for callers.
    """

    def __init__(self, tokenizer: TokenizersBackend, default_system_prompt: str | None = None):
        """Store the tokenizer and the optional default system prompt."""
        self.tokenizer = tokenizer
        self.default_system_prompt = default_system_prompt

    def __call__(self, batch: dict[str, list[Any]]) -> dict[str, torch.Tensor]:
        """Format ``batch["messages"]`` and tokenize the resulting text.

        Args:
            batch: Mapping with a ``"messages"`` entry holding one conversation
                (a list of message dicts) per example.

        Returns:
            The output of ``tokenize`` for the formatted batch.
        """
        final_input = self.format_chat(batch)
        return self.tokenize(final_input)

    def format_chat(
        self, batch: dict[str, list[Any]]
    ) -> dict[str, str | list[Any] | BatchEncoding]:
        """Render each conversation in ``batch["messages"]`` to prompt text.

        The conversations are passed to ``tokenizer.apply_chat_template`` with
        ``add_generation_prompt=True`` and ``tokenize=False``. If that call
        raises ``ValueError`` (presumably because the tokenizer has no chat
        template -- confirm against the transformers docs), each conversation
        must hold exactly one message, whose right-stripped ``"content"`` is
        used as the prompt.

        Args:
            batch: Mapping with a ``"messages"`` entry holding one conversation
                per example.

        Returns:
            A dict with the single key ``"final_input"``.

        Raises:
            AssertionError: If the fallback path is taken and some conversation
                does not contain exactly one message, or if system-prompt
                injection rejects a conversation.
        """
        messages = batch["messages"]
        if self.default_system_prompt is not None:
            messages = self._inject_system_prompt(messages)
        try:
            final_input = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=False,
            )
        except ValueError as err:
            # NOTE(review): with a default system prompt every conversation has
            # at least two messages here, so this check always fails on the
            # fallback path -- confirm that combination is meant to be rejected.
            if not all(len(item) == 1 for item in messages):
                # Explicit raise: an `assert` would be stripped under -O and
                # every message after the first would be silently dropped.
                raise AssertionError(
                    f"Expected single message in batch, got {messages}"
                ) from err
            final_input = [item[0]["content"].rstrip() for item in messages]
        return {"final_input": final_input}

    def _inject_system_prompt(
        self, messages: list[list[dict[str, str]]]
    ) -> list[list[dict[str, str]]]:
        """Prepend a system message to each conversation.

        The input conversations are not modified; new lists are returned.

        Args:
            messages: Conversations, each a non-empty list of message dicts
                whose first message has role ``"user"``.

        Returns:
            A new list of conversations, each starting with a system message
            whose content is ``self.default_system_prompt``.

        Raises:
            AssertionError: If a conversation is empty or its first message
                does not have role ``"user"``.
        """
        # Internal invariant: the caller only invokes this when a prompt is set.
        assert self.default_system_prompt is not None
        result: list[list[dict[str, str]]] = []
        for conv in messages:
            if not conv:
                # Checked separately so that building the error message never
                # indexes into an empty conversation (previously an IndexError).
                raise AssertionError(
                    "Expected first message role to be 'user', got an empty conversation"
                )
            first_role = conv[0].get("role")
            if first_role != "user":
                raise AssertionError(
                    f"Expected first message role to be 'user', got '{first_role}'"
                )
            # A fresh dict per conversation so the results share no state.
            system_msg: dict[str, str] = {
                "role": "system",
                "content": self.default_system_prompt,
            }
            result.append([system_msg, *conv])
        return result

    def tokenize(
        self, batch: dict[str, str | list[Any] | BatchEncoding]
    ) -> dict[str, torch.Tensor]:
        """Tokenize ``batch["final_input"]`` with padding to the longest item.

        The tokenizer is called with ``return_tensors="pt"``,
        ``padding="longest"`` and ``truncation=False``.

        Args:
            batch: Mapping with a ``"final_input"`` entry, as produced by
                ``format_chat``.

        Returns:
            Whatever the tokenizer returns for that call.

        Raises:
            AssertionError: If ``"final_input"`` is missing from ``batch``.
        """
        if "final_input" not in batch:
            raise AssertionError("Final input not found in batch")
        return self.tokenizer(
            batch["final_input"],
            return_tensors="pt",
            padding="longest",
            truncation=False,
        )
