import os
from typing import (
    Any,
    List,
)

import torch

from nptyping import (
    Float32,
    NDArray,
)
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)

from src.vector_db.embedding_fn.base_embedding_fn import BaseEmbeddingFunction
from src.utils.torch import get_available_device
from src.utils.array import batched


# Turn off the Hugging Face `tokenizers` library's internal parallelism for this
# process (presumably to avoid its fork-after-parallelism warning — confirm).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})


class DPREmbeddingFunction(BaseEmbeddingFunction):
    """Embed documents and queries with the DPR bi-encoder (single-nq-base).

    Documents are encoded with the DPR context encoder and queries with the
    DPR question encoder. Both return the encoder's ``pooler_output`` as a
    float32 NumPy array of shape ``(n_texts, dim)``, to be compared with the
    metric reported by :attr:`metric_type`.
    """

    def __init__(self, batch_size: int = 64, max_length: int = 512) -> None:
        """Load both DPR encoders and tokenizers onto the available device.

        Args:
            batch_size: Number of texts tokenized and encoded per forward pass.
            max_length: Maximum tokenized length; longer inputs are truncated.

        Raises:
            ValueError: If ``batch_size`` or ``max_length`` is not positive.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")

        self.ctx_encoder_model_name = "facebook/dpr-ctx_encoder-single-nq-base"
        self.question_encoder_model_name = "facebook/dpr-question_encoder-single-nq-base"
        self.batch_size = batch_size
        self.max_length = max_length

        self.device = get_available_device()

        # Safe to ignore the warning when loading tokenizer: see https://github.com/huggingface/transformers/issues/12926
        self.ctx_encoder_tokenizer = DPRContextEncoderTokenizer.from_pretrained(self.ctx_encoder_model_name)
        self.ctx_encoder = DPRContextEncoder.from_pretrained(self.ctx_encoder_model_name).to(self.device)
        self.ctx_encoder_tokenizer.truncation_side = 'right'
        # DPR's pooler_output is the hidden state of the FIRST token ([CLS]), so both
        # tokenizers must pad on the right; left padding would put [PAD] at position 0.
        # Set as an attribute (not a call kwarg) so it is honored on every transformers version.
        self.ctx_encoder_tokenizer.padding_side = 'right'

        self.question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(self.question_encoder_model_name)
        self.question_encoder = DPRQuestionEncoder.from_pretrained(self.question_encoder_model_name).to(self.device)
        # Over-long queries are cut from the left, i.e. their last tokens are kept.
        self.question_encoder_tokenizer.truncation_side = 'left'
        self.question_encoder_tokenizer.padding_side = 'right'

        # Probe the encoder once to learn the embedding width instead of hard-coding it.
        sample_input = self.ctx_encoder_tokenizer(
            "sample",
            max_length=self.max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            self._dim = self.ctx_encoder(**sample_input).pooler_output.size(-1)

    def _encode(self, tokenizer: Any, encoder: Any, texts: List[str]) -> NDArray[Any, Float32]:
        """Tokenize and encode *texts* in batches, returning stacked pooler outputs.

        Each batch is padded only to its own longest member, and every batch's
        output is moved to the CPU as soon as it is produced.

        Args:
            tokenizer: The DPR tokenizer matching ``encoder``.
            encoder: A DPR context or question encoder already on ``self.device``.
            texts: Strings to embed.

        Returns:
            A float32 array of shape ``(len(texts), dim)``; ``(0, dim)`` for empty input.
        """
        texts = list(texts)
        if not texts:
            # torch.cat on an empty list raises, so short-circuit empty input.
            return torch.empty((0, self._dim), dtype=torch.float32).numpy()

        chunks = []
        with torch.no_grad():
            for start in range(0, len(texts), self.batch_size):
                inputs = tokenizer(
                    texts[start:start + self.batch_size],
                    max_length=self.max_length,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                ).to(self.device)
                output = encoder(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                )
                chunks.append(output.pooler_output.cpu())

        embeddings_npy = torch.cat(chunks, dim=0).numpy().astype("float32")
        assert embeddings_npy.shape[0] == len(texts)
        return embeddings_npy

    def embed_documents(self, documents: List[str]) -> NDArray[Any, Float32]:
        """Embed passages with the DPR context encoder.

        Args:
            documents: Passages to embed; each is truncated on the right to
                ``max_length`` tokens.

        Returns:
            A float32 array of shape ``(len(documents), dim)``.
        """
        return self._encode(self.ctx_encoder_tokenizer, self.ctx_encoder, documents)

    def embed_queries(self, queries: List[str]) -> NDArray[Any, Float32]:
        """Embed search queries with the DPR question encoder.

        Args:
            queries: Queries to embed; each is truncated on the left to
                ``max_length`` tokens.

        Returns:
            A float32 array of shape ``(len(queries), dim)``.
        """
        return self._encode(self.question_encoder_tokenizer, self.question_encoder, queries)

    @property
    def dim(self) -> int:
        """Width of the embedding vectors, measured once at construction."""
        return self._dim

    @property
    def metric_type(self) -> str:
        """Metric name used to compare these embeddings: ``"IP"`` (inner product)."""
        return "IP"
