import os
from typing import (
    Any,
    List,
)

import torch

from nptyping import (
    Float32,
    NDArray,
)
from transformers import (
    AutoModel,
    AutoTokenizer,
)

from src.vector_db.embedding_fn.base_embedding_fn import BaseEmbeddingFunction
from src.utils.torch import get_available_device
from src.utils.array import batched


# Turn off the Hugging Face `tokenizers` library's internal parallelism for
# this process. This runs at import time and overrides any value already set.
os.environ.update(TOKENIZERS_PARALLELISM="false")


class DRAGONEmbeddingFunction(BaseEmbeddingFunction):
    """Embedding function backed by NVIDIA's DRAGON multi-turn dual encoder.

    Documents are encoded with ``nvidia/dragon-multiturn-context-encoder`` and
    queries with ``nvidia/dragon-multiturn-query-encoder``. Both use the query
    encoder's tokenizer. Each text is represented by the final hidden state of
    its first token ([CLS]). Vectors are meant to be compared by inner product
    (see ``metric_type``).
    """

    def __init__(self, *, batch_size: int = 64, max_length: int = 512) -> None:
        """Load both encoders onto the available device and probe the embedding size.

        Args:
            batch_size: Number of texts tokenized and encoded per forward pass.
            max_length: Maximum number of tokens kept per text; longer texts
                are truncated.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size
        self.max_length = max_length

        self.query_encoder_model_name = "nvidia/dragon-multiturn-query-encoder"
        self.context_encoder_model_name = "nvidia/dragon-multiturn-context-encoder"

        self.device = get_available_device()

        self.tokenizer = AutoTokenizer.from_pretrained(self.query_encoder_model_name)
        self.query_encoder = AutoModel.from_pretrained(self.query_encoder_model_name).to(self.device)
        self.context_encoder = AutoModel.from_pretrained(self.context_encoder_model_name).to(self.device)

        # One tiny forward pass reads the embedding width off the model output
        # instead of hard-coding it.
        sample_input = self.tokenizer(
            "sample",
            max_length=self.max_length,
            padding=True,
            padding_side="right",
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            self._dim = self.context_encoder(**sample_input).last_hidden_state.size(-1)

    def _embed(
        self,
        encoder: torch.nn.Module,
        texts: List[str],
        truncation_side: str,
    ) -> NDArray[Any, Float32]:
        """Encode *texts* with *encoder* in batches and return first-token embeddings.

        Args:
            encoder: The query or context encoder to run.
            texts: Texts to embed. A bare string is treated as a single text.
            truncation_side: Which end of an over-long text to drop
                (``"left"`` or ``"right"``).

        Returns:
            A float32 array with one row per input text and ``self.dim`` columns.
            Empty input yields an array of shape ``(0, self.dim)``.
        """
        if isinstance(texts, str):
            texts = [texts]
        texts = list(texts)
        if not texts:
            # torch.cat would raise on an empty list of batches.
            return torch.empty((0, self._dim), dtype=torch.float32).numpy()

        # Truncation direction is tokenizer state with no per-call option, and
        # the tokenizer is shared by both encoders, so it is set on every call.
        self.tokenizer.truncation_side = truncation_side

        chunks = []
        with torch.no_grad():
            for start in range(0, len(texts), self.batch_size):
                # Tokenizing per batch pads each batch only to its own longest
                # text, not to the longest text in the whole input.
                inputs = self.tokenizer(
                    texts[start:start + self.batch_size],
                    max_length=self.max_length,
                    padding=True,
                    # Always pad on the right. The pooling below reads position 0.
                    # With left padding, that position holds [PAD] for every row
                    # shorter than the longest one in the batch.
                    padding_side="right",
                    truncation=True,
                    return_tensors="pt",
                ).to(self.device)
                hidden = encoder(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                ).last_hidden_state
                chunks.append(hidden[:, 0, :])  # Extract [CLS] token embedding
        embeddings = torch.cat(chunks, dim=0).cpu().numpy().astype("float32")
        assert embeddings.shape[0] == len(texts)
        return embeddings

    def embed_documents(self, documents: List[str]) -> NDArray[Any, Float32]:
        """Embed documents with the context encoder.

        Over-long documents are truncated on the right, so their beginning is kept.

        Returns:
            A float32 array of shape ``(len(documents), self.dim)``.
        """
        return self._embed(self.context_encoder, documents, truncation_side="right")

    def embed_queries(self, queries: List[str]) -> NDArray[Any, Float32]:
        """Embed queries with the query encoder.

        Over-long queries are truncated on the left, so their end is kept.
        That end is presumably the most recent turns of a multi-turn
        conversation; confirm against how queries are formatted upstream.

        Returns:
            A float32 array of shape ``(len(queries), self.dim)``.
        """
        return self._embed(self.query_encoder, queries, truncation_side="left")

    @property
    def dim(self) -> int:
        """Width of the embedding vectors, measured from the context encoder's output."""
        return self._dim

    @property
    def metric_type(self) -> str:
        """Similarity metric for these vectors: inner product ("IP")."""
        return "IP"
