import numpy as np
import numpy.typing as npt
import torch
from sentence_transformers import SentenceTransformer


import logging

from .latent_encoder import LatentEncoder

EncodingType = torch.Tensor | npt.NDArray[np.float_]


class EmbeddingEncoder(LatentEncoder[EncodingType, None]):
    """Latent encoder that maps answer strings to semantic embeddings.

    Answers are embedded with a ``SentenceTransformer`` model. The question
    text is accepted for interface compatibility but never reaches the model.

    Args:
        model_name: Model identifier handed to ``SentenceTransformer``.
            Defaults to ``"google/embeddinggemma-300m"``.
        prompt_name: Prompt name forwarded to ``SentenceTransformer.encode``
            on every call. Defaults to ``"STS"``; pass ``None`` for a model
            that does not define a prompt with that name.
    """

    def __init__(
        self,
        model_name: str = "google/embeddinggemma-300m",
        *,
        prompt_name: str | None = "STS",
    ) -> None:
        super().__init__()
        # Stored so every ``encode`` call uses the same prompt.
        self._prompt_name = prompt_name
        self.model = SentenceTransformer(model_name)

    def encode(self, question: str, answer: str) -> EncodingType:
        """Encode ``answer`` with the text embedding model.

        Args:
            question: Ignored; only the answer is embedded.
            answer: Text to embed.

        Returns:
            Whatever ``self.model.encode`` returns for ``answer`` under the
            configured prompt name.
        """
        return self.model.encode(answer, prompt_name=self._prompt_name)

    def __call__(
        self, question: str, answers: dict[str, list[str]]
    ) -> tuple[dict[str, list[EncodingType]], None]:
        """Convert orchestrator run results into aligned latent encodings.

        Args:
            question: Question paired with every answer.
            answers: Mapping from a key to the list of answers for that key.

        Returns:
            A 2-tuple: a dict with the same keys as ``answers`` whose values
            are the results of ``self.batch_encode`` for that key's answers,
            and ``None`` as the second element.
        """
        # Lazy %-style arguments: the (possibly large) dict is only turned
        # into text when DEBUG logging is actually enabled.
        logging.debug(
            "Encoding answers with EmbeddingEncoder answers : %s", answers
        )
        # ``batch_encode`` is not defined in this class (presumably inherited
        # from LatentEncoder); the question is repeated once per answer so
        # both lists passed to it have the same length.
        encoded = {
            key: self.batch_encode([question] * len(ans_list), ans_list)
            for key, ans_list in answers.items()
        }
        return encoded, None
