from typing import (
    Any,
    List,
)

import torch

from nptyping import (
    Float32,
    NDArray,
)
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)

from src.utils.array import batched
from src.utils.torch import get_available_device
from src.vector_db.embedding_fn.base_embedding_fn import BaseEmbeddingFunction


class DPREmbeddingFunction(BaseEmbeddingFunction):
    """Embedding function backed by Facebook's DPR (Dense Passage Retrieval) encoders.

    Documents are embedded with the DPR context encoder and queries with the DPR
    question encoder. Both return the encoder's ``pooler_output`` as a float32
    array with one row per input text and ``dim`` columns. Scores are meant to be
    compared by inner product (see ``metric_type``).
    """

    # Maximum number of tokens per text; longer texts are truncated.
    _MAX_LENGTH = 512
    # Number of texts tokenized and encoded per forward pass.
    _BATCH_SIZE = 64

    def __init__(self) -> None:
        self.ctx_encoder_model_name = "facebook/dpr-ctx_encoder-single-nq-base"
        self.question_encoder_model_name = "facebook/dpr-question_encoder-single-nq-base"
        self.device = get_available_device()

        # Safe to ignore the warning when loading tokenizer: see https://github.com/huggingface/transformers/issues/12926
        self.ctx_encoder_tokenizer = DPRContextEncoderTokenizer.from_pretrained(self.ctx_encoder_model_name)
        self.ctx_encoder = DPRContextEncoder.from_pretrained(self.ctx_encoder_model_name).to(self.device)

        self.question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(self.question_encoder_model_name)
        self.question_encoder = DPRQuestionEncoder.from_pretrained(self.question_encoder_model_name).to(self.device)

        # Determine the embedding width empirically from one forward pass rather
        # than reading it from the model config.
        sample_input = self.ctx_encoder_tokenizer(
            "sample",
            max_length=self._MAX_LENGTH,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            self._dim = self.ctx_encoder(**sample_input).pooler_output.size(-1)

    def _encode(self, texts: List[str], tokenizer: Any, encoder: Any) -> NDArray[Any, Float32]:
        """Tokenize and encode ``texts`` in batches and return the pooled embeddings.

        Each batch is tokenized on its own, so padding only extends to the longest
        text in that batch. The full tokenizer output, including
        ``attention_mask``, is passed to the encoder so padding tokens do not
        affect the embeddings.

        Args:
            texts: Texts to embed.
            tokenizer: Tokenizer matching ``encoder``.
            encoder: DPR encoder whose ``pooler_output`` is used as the embedding.

        Returns:
            A float32 array of shape ``(len(texts), dim)``. For empty input the
            shape is ``(0, dim)``.
        """
        texts = list(texts)
        if not texts:
            # torch.cat([]) raises, so return an explicitly empty matrix instead.
            return torch.empty((0, self._dim), dtype=torch.float32).numpy()

        chunks = []
        with torch.no_grad():
            for start in range(0, len(texts), self._BATCH_SIZE):
                batch = tokenizer(
                    texts[start : start + self._BATCH_SIZE],
                    max_length=self._MAX_LENGTH,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                ).to(self.device)
                # **batch forwards attention_mask so padded positions are ignored.
                pooled = encoder(**batch).pooler_output
                # Move each batch off the device right away so device memory does
                # not grow with the number of texts.
                chunks.append(pooled.cpu())

        embeddings_npy = torch.cat(chunks, dim=0).numpy().astype("float32")
        # Internal invariant: exactly one embedding row per input text.
        assert embeddings_npy.shape[0] == len(texts)
        return embeddings_npy

    def embed_documents(self, documents: List[str]) -> NDArray[Any, Float32]:
        """Embed ``documents`` with the DPR context encoder.

        Args:
            documents: Passages to embed.

        Returns:
            A float32 array of shape ``(len(documents), dim)``.
        """
        return self._encode(documents, self.ctx_encoder_tokenizer, self.ctx_encoder)

    def embed_queries(self, queries: List[str]) -> NDArray[Any, Float32]:
        """Embed ``queries`` with the DPR question encoder.

        Args:
            queries: Questions to embed.

        Returns:
            A float32 array of shape ``(len(queries), dim)``.
        """
        return self._encode(queries, self.question_encoder_tokenizer, self.question_encoder)

    @property
    def dim(self) -> int:
        """Width of each embedding vector, measured from the context encoder at init."""
        return self._dim

    @property
    def metric_type(self) -> str:
        """Similarity metric for these embeddings: inner product ("IP")."""
        return "IP"
