# import sentence_transformers 
from sentence_transformers import SentenceTransformer
import numpy as np
import os
# Turn off HuggingFace tokenizers' internal parallelism. This is presumably
# done to avoid the fork-after-parallelism warning/deadlock when the process
# forks workers (confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class xiaobuRetrieval():
    """Dense retriever over a text corpus using the xiaobu embedding model.

    On construction it:
      * loads a SentenceTransformer model;
      * reads the corpus (one passage per line, with short lines dropped);
      * loads cached passage embeddings from disk, or computes and caches
        them when the cache file is missing or does not match the corpus.

    Calling the instance with a query returns the ``top_k`` most similar
    passages concatenated into a single prompt string.
    """

    DEFAULT_MODEL_PATH = '../checkpoints/lier007__xiaobu-embedding-v2'
    DEFAULT_DATA_PATH = "../data/MEDQA/textbooks/zh_paragraph/all_books.txt"
    DEFAULT_EMBEDDING_PATH = '../RAG/embeddings_2.npy'

    def __init__(self, model_path=DEFAULT_MODEL_PATH, data_path=DEFAULT_DATA_PATH,
                 embedding_file_path=DEFAULT_EMBEDDING_PATH):
        """Load the model and corpus, and load or build the passage embeddings.

        Args:
            model_path: path or name passed to ``SentenceTransformer``.
            data_path: corpus file read by ``read_and_process_file``.
            embedding_file_path: ``.npy`` cache for the passage embeddings.
                The file is (re)written whenever embeddings are computed.

        Note:
            The default paths are relative, so they resolve against the
            current working directory, not against this file.
        """
        self.model = SentenceTransformer(model_path)
        self.column_data = self.read_and_process_file(data_path)

        embeddings_2 = None
        if os.path.exists(embedding_file_path):
            # Reuse the embeddings cached by a previous run.
            embeddings_2 = np.load(embedding_file_path)
            if len(embeddings_2) == len(self.column_data):
                print("Loaded embeddings_2 from file.")
            else:
                # The cache was built from a different corpus. Its rows no
                # longer line up with column_data, so retrieval indices would
                # point at the wrong passages (or out of range).
                print("Cached embeddings_2 does not match the corpus size; recomputing.")
                embeddings_2 = None
        if embeddings_2 is None:
            # Normalise so that a dot product in __call__ equals cosine similarity.
            embeddings_2 = self.model.encode(self.column_data, normalize_embeddings=True)
            np.save(embedding_file_path, embeddings_2)
            print("Computed and saved embeddings_2.")
        self.embeddings_2 = embeddings_2

    def read_and_process_file(self, file_path=DEFAULT_DATA_PATH, min_length=30):
        """Read the corpus file and return its usable passages.

        Args:
            file_path: UTF-8 text file with one passage per line.
            min_length: lines shorter than this (after stripping surrounding
                whitespace) are discarded.

        Returns:
            list[str]: the stripped lines, in file order, that are at least
            ``min_length`` characters long.
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            stripped = (line.strip() for line in file)
            return [line for line in stripped if len(line) >= min_length]

    def __call__(self, query, top_k=3):
        """Retrieve the passages most similar to ``query``.

        Args:
            query: query text to embed.
            top_k: maximum number of passages to include. Values <= 0 include
                none. Values larger than the corpus include every passage.

        Returns:
            str: "检索知识:" followed by each retrieved passage in descending
            similarity order, each terminated by "。".
        """
        prefix = "检索知识:"
        # Guard: slicing with [-0:] would select every passage instead of none,
        # and an empty corpus has no embedding matrix to multiply against.
        if top_k <= 0 or len(self.column_data) == 0:
            return prefix

        query_embedding = self.model.encode([query], normalize_embeddings=True)
        # Both sides are L2-normalised, so this is cosine similarity. That
        # assumes the cached embeddings were also saved normalised, as
        # __init__ does when it computes them.
        similarity = (query_embedding @ self.embeddings_2.T).flatten()

        # argsort is ascending: take the last top_k and reverse for best-first.
        top_indices = np.argsort(similarity)[-top_k:][::-1]
        return prefix + "".join(self.column_data[i] + "。" for i in top_indices)

