from typing import (
    Any,
    List,
)

import torch

from nptyping import (
    Float32,
    NDArray,
)
from transformers import (
    AutoModel,
    AutoTokenizer,
)

from src.vector_db.embedding_fn.base_embedding_fn import BaseEmbeddingFunction


def mean_pooling(token_embeddings, mask):
    """Average token embeddings over the positions selected by ``mask``.

    Positions where ``mask`` is zero are zeroed out before summing along
    dimension 1, and the sum is divided by the per-row mask total.

    Args:
        token_embeddings: Tensor of per-token embeddings; presumably shaped
            (batch, sequence, hidden) -- confirm against callers.
        mask: Tensor with one entry per token (presumably (batch, sequence));
            non-zero marks a token that takes part in the average.

    Returns:
        Tensor with dimension 1 reduced away: one pooled vector per row.
        A row whose mask sums to zero yields an all-zero vector rather than
        NaN.
    """
    keep = mask[..., None].bool()
    masked_embeddings = token_embeddings.masked_fill(~keep, 0.0)

    counts = mask.sum(dim=1)[..., None]
    # A zero mask total would give 0/0 = NaN; swap only those counts for 1 so
    # every row with a non-zero total is computed exactly as before.
    counts = counts.masked_fill(counts == 0, 1)

    return masked_embeddings.sum(dim=1) / counts


class ContrieverEmbeddingFunction(BaseEmbeddingFunction):
    """Embedding function backed by the ``facebook/contriever`` checkpoint.

    Texts are tokenized (truncated to at most 512 tokens, with padding
    enabled), run through the model without gradient tracking, and the
    resulting token embeddings are reduced to one vector per text with
    ``mean_pooling`` using the tokenizer's attention mask. Documents and
    queries are embedded by exactly the same procedure.
    """

    # Token limit handed to the tokenizer; longer inputs are truncated.
    _MAX_LENGTH = 512

    def __init__(self) -> None:
        """Load the tokenizer and model, then probe the embedding width."""
        self.model_name = "facebook/contriever"

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name)

        # One tiny forward pass so the width is read from a real model output
        # (the last axis of ``last_hidden_state``).
        sample_input = self._tokenize("sample")
        with torch.no_grad():
            self._dim = self.model(**sample_input).last_hidden_state.size(-1)

    def _tokenize(self, texts):
        """Tokenize ``texts`` into PyTorch tensors with the shared settings."""
        return self.tokenizer(
            texts,
            max_length=self._MAX_LENGTH,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

    def _embed(self, texts: List[str]) -> NDArray[Any, Float32]:
        """Embed ``texts`` into a float32 array with one row per text."""
        if len(texts) == 0:
            # Nothing to encode: skip the tokenizer/model and return an empty
            # matrix that still has the right embedding width.
            return torch.empty((0, self._dim), dtype=torch.float32).numpy()

        inputs = self._tokenize(texts)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # outputs[0] holds the per-token embeddings fed to the pooling step.
        pooled = mean_pooling(outputs[0], inputs["attention_mask"])
        return pooled.cpu().numpy().astype("float32")

    def embed_documents(self, documents: List[str]) -> NDArray[Any, Float32]:
        """Embed documents.

        Args:
            documents: Texts to embed; may be empty.

        Returns:
            float32 array with one row per document and ``dim`` columns.
        """
        return self._embed(documents)

    def embed_queries(self, queries: List[str]) -> NDArray[Any, Float32]:
        """Embed queries (same procedure as ``embed_documents``).

        Args:
            queries: Texts to embed; may be empty.

        Returns:
            float32 array with one row per query and ``dim`` columns.
        """
        return self._embed(queries)

    @property
    def dim(self) -> int:
        """Embedding width measured from the model output at construction."""
        return self._dim

    @property
    def metric_type(self) -> str:
        """Similarity metric identifier to use with these embeddings."""
        # NOTE(review): "IP" presumably means inner product to the vector DB
        # that consumes this value -- confirm against the base class.
        return "IP"
