from typing import List, Dict, Tuple, Set
import numpy as np
import os
import time
import tqdm
import torch
import fasttext
from gensim.models import Word2Vec, KeyedVectors
from sentence_transformers import SentenceTransformer
import logging
import openai
from beir import LoggingHandler
from collections import Counter
from gensim.parsing.preprocessing import remove_stopwords, preprocess_string
from gensim.utils import simple_preprocess
import hashlib
from mistralai import Mistral
# Configure the root logger: INFO level, timestamped messages, routed through
# BEIR's LoggingHandler. `logger` is the root logger used throughout this module.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)
logger = logging.getLogger()
class EmbeddingGenerator:
    """Base class for text-embedding backends with a simple ``.npy`` file cache.

    Subclasses implement ``_generate_batch_embeddings`` (and may override
    ``_generate_query_embeddings`` when queries must be encoded differently
    from documents). ``generate_embeddings`` and its corpus/query wrappers
    return a cached array when one exists and otherwise compute and cache it.
    """

    def __init__(self, model_name: str, dataset_name: str, cache_dir: str, device: torch.device, batch_size: int):
        self.model_name = model_name
        self.dataset_name = dataset_name
        # Directory subclasses use to download/load pretrained model weights.
        self.pretrained_weights_dir = "./pretrained_weights"
        self.cache_dir = cache_dir
        self.device = device
        self.batch_size = batch_size

    def _get_cache_key(self, texts: List[str], batch_size: int, model_name: str, create_cache_folder: bool = True) -> Tuple[str, str]:
        """Derive a cache key and a per-key cache folder for ``texts``.

        NOTE(review): the key hashes only the first 10 characters of the first
        text, the number of texts, the batch size and the model name, so two
        different inputs sharing those values collide. It is kept unchanged so
        existing caches stay valid.

        Returns:
            ``(cache_key, cache_folder)``: the MD5 hex digest and
            ``<cache_dir>/cache/<cache_key>`` (created unless
            ``create_cache_folder`` is False).
        """
        cache_key_raw = (texts[0][:10] if texts else "") + str(len(texts)) + str(batch_size) + model_name
        cache_key = hashlib.md5(cache_key_raw.encode('utf-8')).hexdigest()
        cache_folder = os.path.join(self.cache_dir, "cache", cache_key)
        if create_cache_folder:
            os.makedirs(cache_folder, exist_ok=True)
        return cache_key, cache_folder

    def _cache_path(self, cache_key: str) -> str:
        """Return the ``.npy`` file that caches embeddings for ``cache_key``.

        ``np.save`` appends ``.npy`` when the file name lacks it. The existence
        check must therefore look at the same normalized path, otherwise the
        cache is written but never found again.
        """
        path = os.path.join(self.cache_dir, cache_key)
        if not path.endswith(".npy"):
            path += ".npy"
        return path

    def _generate_batch_embeddings(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Embed ``texts`` as documents; must be implemented by subclasses."""
        raise NotImplementedError("You need to implement this method in the subclass.")

    def _generate_query_embeddings(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Embed ``texts`` as queries; defaults to the document encoder."""
        return self._generate_batch_embeddings(texts, batch_size)

    def generate_embeddings(self, text_list: List[str], cache_key: str, type: str = "corpus") -> np.ndarray:
        """Return embeddings for ``text_list``, loading from or writing to the cache.

        Args:
            text_list: texts to embed.
            cache_key: file name (relative to ``cache_dir``) of the cache entry;
                ``.npy`` is appended if missing.
            type: ``"corpus"`` or ``"query"``. The name shadows the builtin but
                is kept because callers pass it by keyword.

        Raises:
            ValueError: if ``type`` is neither ``"corpus"`` nor ``"query"``.
        """
        # Validate first: previously an unknown type fell through and raised
        # UnboundLocalError on `embeddings`.
        if type not in ("corpus", "query"):
            raise ValueError(f"Invalid embedding type: {type!r} (expected 'corpus' or 'query')")
        cache_path = self._cache_path(cache_key)
        if os.path.exists(cache_path):
            embeddings = np.load(cache_path)
            logger.info(f"Embeddings for {len(text_list)} texts loaded from {cache_key}.")
            return embeddings
        if type == "corpus":
            embeddings = self._generate_batch_embeddings(text_list, batch_size=self.batch_size)
        else:
            embeddings = self._generate_query_embeddings(text_list, batch_size=self.batch_size)
        os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
        np.save(cache_path, embeddings)
        logger.info(f"Embeddings for {len(text_list)} texts generated and cached to {cache_path}.")
        return embeddings

    def generate_corpus_embeddings(self, corpus: Dict[str, Dict[str, str]], cache_key: str) -> np.ndarray:
        """Embed the ``'text'`` field of each corpus entry, in dict iteration order."""
        text_list = [item['text'] for item in corpus.values()]
        return self.generate_embeddings(text_list, cache_key, type="corpus")

    def generate_query_embeddings(self, queries: Dict[str, str], cache_key: str) -> np.ndarray:
        """Embed each query string, in dict iteration order."""
        text_list = list(queries.values())
        return self.generate_embeddings(text_list, cache_key, type="query")
class NVEmbedEmbeddingGenerator(EmbeddingGenerator):
    """Embedding backend for ``nvidia/NV-Embed-v2`` loaded through SentenceTransformer.

    Embeddings are requested L2-normalized and written to a float32 memmap
    (``all_embeddings.dat``) inside the folder returned by ``_get_cache_key``.
    Each batch is also saved as ``<start_index>.npy`` so an interrupted run can
    resume without re-encoding completed batches.
    """

    # Texts are truncated to this many characters (not tokens) before the
    # tokenizer's EOS token is appended.
    max_length = 512

    def __init__(self, cache_dir: str, dataset_name: str, device: torch.device, batch_size: int):
        super().__init__("NV-Embed", dataset_name, cache_dir, device, batch_size)
        self.model = SentenceTransformer('nvidia/NV-Embed-v2', trust_remote_code=True, cache_folder=self.pretrained_weights_dir)

    @staticmethod
    def _is_valid_text(text) -> bool:
        """Return True for non-empty strings; any other entry is not encoded."""
        return isinstance(text, str) and bool(text)

    def _encode(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Truncate each text, append the EOS token and encode with normalization."""
        eos_token = self.model.tokenizer.eos_token
        return self.model.encode(
            [text[:self.max_length] + eos_token for text in texts],
            batch_size=batch_size,
            normalize_embeddings=True,
            show_progress_bar=False,
        )

    def _generate_batch_embeddings(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Encode ``texts`` into a read-only ``(len(texts), dim)`` float32 memmap.

        Row ``k`` always corresponds to ``texts[k]``. Rows for empty or
        non-string entries are left as zeros.

        Raises:
            ValueError: if ``texts`` contains no non-empty string.
        """
        _, cache_folder = self._get_cache_key(texts, batch_size, self.model_name)
        output_file = os.path.join(cache_folder, "all_embeddings.dat")
        # Fill a temporary file and rename it only once complete. An interrupted
        # run must never leave a half-filled all_embeddings.dat that a later run
        # would return as finished output.
        partial_file = output_file + ".partial"

        # Encode one valid text to discover the embedding dimension.
        sample_text = next((text for text in texts if self._is_valid_text(text)), None)
        if sample_text is None:
            raise ValueError("No valid texts to embed")
        embedding_dim = self._encode([sample_text], batch_size=1).shape[1]
        shape = (len(texts), embedding_dim)

        if os.path.exists(output_file):
            return np.memmap(output_file, dtype='float32', mode='r', shape=shape)

        fp = np.memmap(partial_file, dtype='float32', mode='w+', shape=shape)
        total_batches = (len(texts) + batch_size - 1) // batch_size
        for start in tqdm.tqdm(range(0, len(texts), batch_size), total=total_batches, desc="Generating embeddings"):
            batch = texts[start:start + batch_size]
            # Keep the positions of valid texts, so their embeddings land in the
            # matching rows even when invalid entries are skipped.
            valid_offsets = [offset for offset, text in enumerate(batch) if self._is_valid_text(text)]
            if not valid_offsets:
                continue
            cache_file = os.path.join(cache_folder, f"{start}.npy")
            if os.path.exists(cache_file):
                # Per-batch caches hold the valid texts' embeddings in order.
                batch_emb = np.load(cache_file)
            else:
                batch_emb = self._encode([batch[offset] for offset in valid_offsets], batch_size=batch_size)
                np.save(cache_file, batch_emb)
            fp[[start + offset for offset in valid_offsets]] = batch_emb
        fp.flush()
        # Drop the writable mapping before renaming the file it maps.
        del fp
        os.replace(partial_file, output_file)
        return np.memmap(output_file, dtype='float32', mode='r', shape=shape)
class FastTextEmbeddingGenerator(EmbeddingGenerator):
    """Sentence embeddings from a per-dataset fastText model.

    The model is loaded from
    ``cached_output/fast_text/<dataset_name>_fasttext.model``, a path relative
    to the current working directory.
    """

    def __init__(self, cache_dir: str, dataset_name: str, device: torch.device, batch_size: int):
        super().__init__("Fast-Text", dataset_name, cache_dir, device, batch_size)
        model_path = f"cached_output/fast_text/{self.dataset_name}_fasttext.model"
        self.model = fasttext.load_model(model_path)

    def _generate_batch_embeddings(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Return one fastText sentence vector per text; ``batch_size`` is unused."""
        vectors = []
        for text in texts:
            # Newlines are flattened to spaces before encoding; presumably
            # fastText's sentence API expects a single line (verify).
            single_line = text.replace("\n", " ")
            vectors.append(self.model.get_sentence_vector(single_line))
        return np.array(vectors)
class Word2VecEmbeddingGenerator(EmbeddingGenerator):
    """Mean-of-word-vectors document embeddings from the GoogleNews word2vec model.

    Vectors are loaded from ``~/GoogleNews-vectors-negative300.bin`` (binary
    word2vec format) via gensim ``KeyedVectors``.
    """

    def __init__(self, cache_dir: str, dataset_name: str, device: torch.device, batch_size: int):
        super().__init__("Word2Vec", dataset_name, cache_dir, device, batch_size)
        model_path = os.path.expanduser("~/GoogleNews-vectors-negative300.bin")
        self.model = KeyedVectors.load_word2vec_format(model_path, binary=True)
        # gensim's simple_preprocess tokenizer.
        # NOTE(review): it lowercases tokens, while the GoogleNews vocabulary is
        # case-sensitive; confirm this is intended.
        self.preprocess = simple_preprocess

    def _generate_batch_embeddings(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Return one mean word vector per text as a ``(len(texts), vector_size)`` array.

        Out-of-vocabulary tokens are ignored. A text for which no mean can be
        computed gets a zero vector. The result always uses the model's vector
        dtype. ``batch_size`` is unused.
        """
        dtype = self.model.vectors.dtype
        embeddings = []
        for text in texts:
            words = self.preprocess(text)
            try:
                doc_vector = self.model.get_mean_vector(words, ignore_missing=True)
            except (KeyError, ValueError):
                # Zero vector in the model's dtype, so the result dtype does not
                # depend on whether some text had no usable tokens.
                doc_vector = np.zeros(self.model.vector_size, dtype=dtype)
            embeddings.append(doc_vector)
        if not embeddings:
            # Keep a 2-D shape for empty input instead of np.array([]) -> (0,).
            return np.zeros((0, self.model.vector_size), dtype=dtype)
        return np.array(embeddings, dtype=dtype)
class MistralEmbeddingGenerator(EmbeddingGenerator):
    """Embedding backend calling the Mistral ``mistral-embed`` API, with per-batch ``.npy`` caching."""

    def __init__(self, cache_dir: str, dataset_name: str, batch_size: int, api_key: str):
        super().__init__("Mistral", dataset_name, cache_dir, None, batch_size)
        self.api_key = api_key
        self.model_name = "mistral-embed"
        self.client = Mistral(api_key=self.api_key)
        # Per-text truncation limit. It is applied by string slicing, so it
        # counts characters, not model tokens.
        self.max_tokens = 8192

    def _generate_batch_embeddings(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Embed ``texts`` via the API, reusing per-batch results cached on disk.

        Each batch response is saved to ``<cache_dir>/<key>/<start_index>.npy``,
        so a re-run skips batches that were already fetched.

        Raises:
            ValueError: if ``texts`` is empty.
        """
        if not texts:
            raise ValueError("No texts to embed")
        # NOTE(review): the key covers only the first 10 characters of the first
        # text, the text count and the batch size (not the model name). Kept
        # as-is so existing cache folders remain valid.
        cache_key_raw = (texts[0][:10] if texts else "") + str(len(texts)) + str(batch_size)
        cache_key = hashlib.md5(cache_key_raw.encode('utf-8')).hexdigest()
        cache_folder = os.path.join(self.cache_dir, cache_key)
        os.makedirs(cache_folder, exist_ok=True)

        embeddings = []
        for i in tqdm.tqdm(range(0, len(texts), batch_size), desc="Generating batch embeddings (Mistral)"):
            cache_file = os.path.join(cache_folder, f"{i}.npy")
            if os.path.exists(cache_file):
                embeddings.append(np.load(cache_file))
                continue
            batch_texts = [text[:self.max_tokens] for text in texts[i:i + batch_size]]
            response = self.client.embeddings.create(
                inputs=batch_texts,
                model=self.model_name
            )
            # Fixed pause between API calls, presumably to throttle the request rate.
            time.sleep(2)
            batch_emb = np.array([r.embedding for r in response.data])
            np.save(cache_file, batch_emb)
            embeddings.append(batch_emb)
        return np.vstack(embeddings)
class OpenAIEmbeddingGenerator(EmbeddingGenerator):
    """Embedding backend calling the OpenAI ``text-embedding-3-small`` API."""

    def __init__(self, cache_dir: str, dataset_name: str, device: torch.device, batch_size: int, api_key: str):
        super().__init__("OpenAI", dataset_name, cache_dir, device, batch_size)
        self.api_key = api_key
        self.model_name = "text-embedding-3-small"
        # Sets the key globally on the `openai` module (affects all users of it).
        openai.api_key = self.api_key
        # Per-text truncation limit. It is applied by string slicing, so it
        # counts characters, not model tokens.
        self.max_tokens = 8192 * 4

    def _generate_batch_embeddings(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Embed ``texts`` in API batches of ``batch_size`` and stack the results row-wise."""
        chunks = []
        batch_starts = range(0, len(texts), batch_size)
        for start in tqdm.tqdm(batch_starts, desc="Generating batch embeddings (OpenAI)"):
            # Empty strings are sent as a single space; other texts are truncated.
            payload = [" " if text == "" else text[:self.max_tokens] for text in texts[start:start + batch_size]]
            response = openai.embeddings.create(input=payload, model=self.model_name)
            # Fixed pause between API calls, presumably to throttle the request rate.
            time.sleep(3)
            chunks.append(np.array([item.embedding for item in response.data]))
        return np.vstack(chunks)
class GTEEmbeddingGenerator(EmbeddingGenerator):
    """Embedding backend for ``Alibaba-NLP/gte-Qwen2-7B-instruct`` via SentenceTransformer.

    Queries are encoded with the model's ``"query"`` prompt; documents are
    encoded without a prompt.
    NOTE(review): ``device`` is stored but not passed to SentenceTransformer,
    and weights are not placed in ``pretrained_weights_dir``; confirm intended.
    """

    def __init__(self, cache_dir: str, dataset_name: str, device: torch.device, batch_size: int):
        super().__init__("GTE", dataset_name, cache_dir, device, batch_size)
        gte_model = SentenceTransformer("Alibaba-NLP/gte-Qwen2-7B-instruct", trust_remote_code=True)
        gte_model.max_seq_length = 8192
        self.model = gte_model

    def _encode(self, texts: List[str], batch_size: int, **encode_kwargs) -> np.ndarray:
        """Forward ``texts`` to ``SentenceTransformer.encode`` with any extra options."""
        return self.model.encode(texts, batch_size=batch_size, **encode_kwargs)

    def _generate_batch_embeddings(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Encode documents (no prompt)."""
        return self._encode(texts, batch_size)

    def _generate_query_embeddings(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Encode queries using the model's ``"query"`` prompt."""
        return self._encode(texts, batch_size, prompt_name="query")
class BGEEmbeddingGenerator(EmbeddingGenerator):
    """Embedding backend for ``BAAI/bge-en-icl`` (in-context-learning embedder) via FlagEmbedding."""

    def __init__(self, cache_dir: str, dataset_name: str, device: torch.device, batch_size: int):
        # FlagICLModel was referenced but never imported in this module, so
        # constructing this class raised NameError. Importing it here also keeps
        # FlagEmbedding optional for users of the other backends.
        from FlagEmbedding import FlagICLModel

        super().__init__("BGE", dataset_name, cache_dir, device, batch_size)
        instruction = 'Given a web search query, retrieve relevant passages that answer the query.'
        # Few-shot (instruct, query, response) examples passed to the model as
        # in-context demonstrations for the retrieval task.
        examples = [
            {'instruct': instruction,
             'query': 'what is a virtual interface',
             'response': "A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes."},
            {'instruct': instruction,
             'query': 'causes of back pain in female for a week',
             'response': "Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management."}
        ]
        self.model = FlagICLModel('BAAI/bge-en-icl',
                                  query_instruction_for_retrieval=instruction,
                                  examples_for_task=examples,
                                  cache_dir=self.pretrained_weights_dir,
                                  use_fp16=False)

    def _generate_batch_embeddings(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Encode documents with the model's corpus encoder."""
        return self.model.encode_corpus(texts, batch_size=batch_size, convert_to_numpy=True)

    def _generate_query_embeddings(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Encode queries with the model's query encoder (instruction + examples)."""
        return self.model.encode_queries(texts, batch_size=batch_size, convert_to_numpy=True)
class GloVeEmbeddingGenerator(EmbeddingGenerator):
    """Mean-of-word-vectors document embeddings from pretrained GloVe 6B vectors.

    Vectors are read from ``~/glove.6B/glove.6B.<dim>d.txt`` into an in-memory
    ``{word: float32 vector}`` dict at construction time.

    Raises:
        ValueError: if ``dim`` is not one of 50, 100, 200, 300.
        FileNotFoundError: if the vector file is missing.
    """

    def __init__(self, cache_dir: str, dataset_name: str, device: torch.device, batch_size: int, dim: int = 300):
        super().__init__("GloVe", dataset_name, cache_dir, device, batch_size)
        if dim not in [50, 100, 200, 300]:
            raise ValueError("GloVe dimension must be one of: 50, 100, 200, 300")
        model_path = os.path.expanduser(f"~/glove.6B/glove.6B.{dim}d.txt")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"GloVe vectors not found at {model_path}. Please ensure glove.6B.zip is extracted to ~/glove.6B/")
        self.model = {}
        logger.info(f"Loading GloVe {dim}d vectors...")
        with open(model_path, 'r', encoding='utf-8') as f:
            for line in tqdm.tqdm(f):
                # Each line: a word followed by its whitespace-separated components.
                values = line.split()
                word = values[0]
                vector = np.asarray(values[1:], dtype='float32')
                self.model[word] = vector
        self.vector_size = dim
        self.preprocess = simple_preprocess
        logger.info(f"Loaded {len(self.model)} word vectors of dimension {dim}")

    def _generate_batch_embeddings(self, texts: List[str], batch_size: int) -> np.ndarray:
        """Return one averaged GloVe vector per text as a ``(len(texts), dim)`` float32 array.

        Tokens missing from the vocabulary are skipped. A text with no known
        tokens gets a zero vector. ``batch_size`` is unused.
        """
        embeddings = []
        for text in texts:
            vectors = [self.model[word] for word in self.preprocess(text) if word in self.model]
            if vectors:
                embeddings.append(np.mean(vectors, axis=0))
            else:
                # Zero vector in float32, so the result dtype does not depend on
                # whether some text had no known tokens.
                embeddings.append(np.zeros(self.vector_size, dtype=np.float32))
        if not embeddings:
            # Keep a 2-D shape for empty input instead of np.array([]) -> (0,).
            return np.zeros((0, self.vector_size), dtype=np.float32)
        return np.array(embeddings, dtype=np.float32)
def get_embedding_generator(model_name: str, dataset_name: str, cache_dir: str, device: torch.device, batch_size: int, **kwargs):
    """Instantiate the embedding generator registered under ``model_name``.

    Matching is case-insensitive. Recognized names: ``nv-embed``, ``fast-text``,
    ``word2vec``, ``mistral``, ``openai`` / ``gpt3``, ``bge``, ``gte``, ``glove``.
    The Mistral and OpenAI backends receive their API keys from the
    ``MISTRAL_API_KEY`` / ``OPENAI_API_KEY`` environment variables. GloVe takes
    an optional ``dim`` keyword (default 300).

    Raises:
        ValueError: if ``model_name`` is not recognized.
    """
    def build_openai():
        return OpenAIEmbeddingGenerator(cache_dir, dataset_name, device, batch_size, os.getenv("OPENAI_API_KEY"))

    # Builders are lazy so only the selected model is constructed (and only its
    # environment variable is read).
    builders = {
        "nv-embed": lambda: NVEmbedEmbeddingGenerator(cache_dir, dataset_name, device, batch_size),
        "fast-text": lambda: FastTextEmbeddingGenerator(cache_dir, dataset_name, device, batch_size),
        "word2vec": lambda: Word2VecEmbeddingGenerator(cache_dir, dataset_name, device, batch_size),
        "mistral": lambda: MistralEmbeddingGenerator(cache_dir, dataset_name, batch_size, os.getenv("MISTRAL_API_KEY")),
        "openai": build_openai,
        "gpt3": build_openai,
        "bge": lambda: BGEEmbeddingGenerator(cache_dir, dataset_name, device, batch_size),
        "gte": lambda: GTEEmbeddingGenerator(cache_dir, dataset_name, device, batch_size),
        "glove": lambda: GloVeEmbeddingGenerator(cache_dir, dataset_name, device, batch_size, dim=kwargs.get('dim', 300)),
    }
    builder = builders.get(model_name.lower())
    if builder is None:
        raise ValueError(f"Invalid model name: {model_name}")
    return builder()
