import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from typing import Any, Dict, List, Optional
from langchain.pydantic_v1 import BaseModel, Extra, Field
from langchain.schema.embeddings import Embeddings
import numpy as np
from tqdm import tqdm

from faiss_customized import FAISS
from langchain.docstore.document import Document

def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    """Pick, per sequence, the hidden state of its last attended position.

    Args:
        last_hidden_states: model output, presumably (batch, seq_len, hidden)
            — TODO confirm; the first two axes are indexed as batch/position.
        attention_mask: presumably (batch, seq_len) with 1 for real tokens
            and 0 for padding.

    Returns:
        One hidden-state row per sequence in the batch.
    """
    n_rows = attention_mask.shape[0]
    # Every sequence attends to the final column -> the batch is left-padded
    # (or not padded at all), so that column already holds each last token.
    if attention_mask[:, -1].sum() == n_rows:
        return last_hidden_states[:, -1]

    # Right padding: the last real token sits at (number of ones - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_ids = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_ids, last_positions]
        

class TransformerEmbeddings(BaseModel, Embeddings):
    """LangChain ``Embeddings`` backed by a Hugging Face transformer model.

    Each text is tokenized (with an EOS token appended), run through the model,
    pooled with ``last_token_pool`` and L2-normalized.
    """

    client: Any  #: :meta private:
    tokenizer: Any
    model_name: str = 'Salesforce/SFR-Embedding-Mistral'
    """Model name to use."""
    # NOTE(review): machine-specific default cache path — override via kwargs elsewhere.
    cache_folder: Optional[str] = "/data/user_data/  l5/.cache"
    """Path to store models."""
    device: str = "cuda"
    """Torch device the model and inputs are moved to."""
    # NOTE(review): model_kwargs and encode_kwargs are accepted but not used below.
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""
    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass when calling the `encode` method of the model."""

    def __init__(self, **kwargs: Any):
        """Validate fields, then load the tokenizer and model onto ``self.device``."""
        super().__init__(**kwargs)

        # Honor the configured model_name (it was previously hard-coded and the
        # field silently ignored; the default is the same checkpoint).
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, cache_dir=self.cache_folder)
        self.client = AutoModel.from_pretrained(self.model_name, cache_dir=self.cache_folder)
        self.client.to(self.device)
        # Append EOS to every encoding so last_token_pool reads the EOS position.
        self.tokenizer.add_eos_token = True

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Texts are encoded one at a time (batch size 1).

        Args:
            texts: The list of texts to embed.

        Returns:
            List of L2-normalized embeddings, one for each text (empty for
            empty input).
        """
        max_length = 4096

        new_embeddings: List[List[float]] = []
        # Inference only: without no_grad the autograd graph (and every saved
        # activation of the forward pass) is kept alive on the device.
        with torch.no_grad():
            for text in texts:
                # max_length - 1 presumably leaves room for the appended EOS token.
                batch_dict = self.tokenizer(
                    [text],
                    max_length=max_length - 1,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                )
                # Move every tokenizer output to the model's device.
                batch_dict = {key: value.to(self.device) for key, value in batch_dict.items()}
                outputs = self.client(**batch_dict)
                embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
                embeddings = F.normalize(embeddings, p=2, dim=1)
                new_embeddings.append(embeddings[0].tolist())
        return new_embeddings

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a HuggingFace transformer model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text (same encoding path as ``embed_documents``).
        """
        return self.embed_documents([text])[0]
    

class Store:
    """Holds one FAISS vector store per named document and scores texts against them."""

    def __init__(self, embed_model):
        """
        Args:
            embed_model: embeddings object handed to FAISS for indexing and querying.
        """
        self.embed_model = embed_model
        self.docs = {}       # doc_name -> FAISS vector store
        self.doc_sizes = {}  # doc_name -> number of indexed chunks

    def load_doc(self, doc_name, doc_lst):
        """Embed the strings in *doc_lst* and index them under *doc_name*.

        Each string becomes a Document whose metadata records the document name,
        its position in *doc_lst* and its text.
        """
        self.docs[doc_name] = FAISS.from_documents(
            [Document(page_content=i, metadata={'doc_name': doc_name, 'index': idx, 'content': i}) for idx, i in enumerate(doc_lst)],
            self.embed_model
        )
        self.doc_sizes[doc_name] = len(doc_lst)

    def load_embeddings(self, doc_name, path):
        """Load a FAISS store saved at *path* and register it under *doc_name*."""
        store = FAISS.load_local(path, self.embed_model)
        self.docs[doc_name] = store
        # get_similarity_score needs the size; it was never recorded here, so
        # stores loaded from disk could not be queried. Assumes faiss_customized
        # keeps LangChain's ``index`` attribute (a faiss index) — verify.
        self.doc_sizes[doc_name] = store.index.ntotal

    def get_similarity_score(self, doc_name: str, query: str):
        """Score *query* against every chunk of *doc_name*.

        Returns:
            The store's ``_similarity_search_with_relevance_scores`` result with
            ``k`` equal to the number of chunks: (document, score) pairs.

        Raises:
            KeyError: if *doc_name* has not been loaded.
        """
        # Explicit check instead of assert (asserts are stripped under -O).
        if doc_name not in self.docs or doc_name not in self.doc_sizes:
            raise KeyError(f"document {doc_name!r} has not been loaded")
        return self.docs[doc_name]._similarity_search_with_relevance_scores(
            query, k=self.doc_sizes[doc_name]
        )

    def sort_samples(self, doc_name1, doc2_lst):
        """Rank texts in *doc2_lst* by their mean relevance score against *doc_name1*.

        Returns:
            List of ``(avg_score, idx)`` sorted ascending by ``avg_score``, where
            ``idx`` is the position of the text in *doc2_lst*. Ties keep input order.
        """
        avg_scores = []
        # Wrap the list itself (not enumerate) so tqdm knows the total length.
        for idx, doc_2 in enumerate(tqdm(doc2_lst)):
            doc_w_scores = self.get_similarity_score(doc_name1, doc_2)
            avg_score = np.mean([score for _, score in doc_w_scores])
            avg_scores.append((avg_score, idx))

        avg_scores.sort(key=lambda pair: pair[0])
        return avg_scores
