from typing import (
    Any,
    List,
)

import torch

from nptyping import (
    Float32,
    NDArray,
)
from transformers import (
    AutoModel,
    AutoTokenizer,
)

from src.vector_db.embedding_fn.base_embedding_fn import BaseEmbeddingFunction


class DRAGONEmbeddingFunction(BaseEmbeddingFunction):
    """Embedding function backed by NVIDIA's DRAGON multi-turn dual encoders.

    Queries are embedded with ``nvidia/dragon-multiturn-query-encoder`` and
    documents with ``nvidia/dragon-multiturn-context-encoder``. Both encoders
    share the query encoder's tokenizer. Each text is represented by the
    last-layer hidden state at its first token position ([CLS]).
    """

    def __init__(self, max_length: int = 512) -> None:
        """Load the shared tokenizer and both encoders, then probe the output width.

        Args:
            max_length: Maximum number of tokens per text; longer texts are
                truncated. Defaults to 512 (the previously hard-coded value).
        """
        query_encoder_model_name = "nvidia/dragon-multiturn-query-encoder"
        context_encoder_model_name = "nvidia/dragon-multiturn-context-encoder"

        self._max_length = max_length
        # The query encoder's tokenizer is used for both queries and documents.
        self.tokenizer = AutoTokenizer.from_pretrained(query_encoder_model_name)
        self.query_encoder = AutoModel.from_pretrained(query_encoder_model_name)
        self.context_encoder = AutoModel.from_pretrained(context_encoder_model_name)

        # Discover the embedding width with one tiny forward pass instead of
        # assuming a particular hidden size. The input is non-empty, so
        # _encode does not need self._dim yet.
        self._dim = self._encode(self.query_encoder, ["sample"]).shape[1]

    def _encode(self, encoder: torch.nn.Module, texts: List[str]) -> NDArray[Any, Float32]:
        """Embed ``texts`` with ``encoder``.

        Args:
            encoder: Either ``self.query_encoder`` or ``self.context_encoder``.
            texts: Strings to embed.

        Returns:
            A float32 array with one row per input text.
        """
        if not texts:
            # Nothing to tokenize or encode: return a (0, dim) matrix so
            # callers still get a consistently shaped float32 result.
            return torch.empty((0, self._dim), dtype=torch.float32).numpy()

        inputs = self.tokenizer(
            texts,
            max_length=self._max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            # Position 0 of the last hidden state is the [CLS] token embedding.
            embeddings = encoder(**inputs).last_hidden_state[:, 0, :]
        return embeddings.cpu().numpy().astype("float32")

    def embed_documents(self, documents: List[str]) -> NDArray[Any, Float32]:
        """Embed documents with the context encoder.

        Args:
            documents: Document texts to embed.

        Returns:
            A float32 array of shape ``(len(documents), dim)``.
        """
        return self._encode(self.context_encoder, documents)

    def embed_queries(self, queries: List[str]) -> NDArray[Any, Float32]:
        """Embed queries with the query encoder.

        Args:
            queries: Query texts to embed.

        Returns:
            A float32 array of shape ``(len(queries), dim)``.
        """
        return self._encode(self.query_encoder, queries)

    @property
    def dim(self) -> int:
        """Width of each embedding vector, measured when the model is loaded."""
        return self._dim

    @property
    def metric_type(self) -> str:
        """Similarity metric identifier for the vector store ("IP", i.e. inner product)."""
        return "IP"
