# ---------------------------------------------------------------
# Apply Retrieval-Augmented Generation (RAG) to build the ft dataset
# ---------------------------------------------------------------

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import json

# Initialize the BERT model from Sentence Transformers
model = SentenceTransformer('bert-base-nli-mean-tokens')

# Load the retrieval corpus from disk. The index built below reads
# doc['question'] from every entry, so each document must provide that key.
file_name = "rag_documents.json"
# Decode explicitly as UTF-8 rather than the platform's locale default, so
# non-ASCII text in the JSON file loads the same way on every system.
with open(file_name, 'r', encoding='utf-8') as file:
	documents = json.load(file)

# Embed documents: only the 'question' text of each entry is encoded, and
# row i of doc_embeddings corresponds to documents[i].
index = [doc['question'] for doc in documents]
doc_embeddings = model.encode(index)


def retrieve_documents(query, doc_embeddings, documents, k=3):
	"""Return the documents most similar to ``query``, best match first.

	Args:
		query: Text to search for. It is encoded with the module-level
			``model``.
		doc_embeddings: Embeddings to compare against, one row per document,
			in the same order as ``documents``.
		documents: Sequence of documents, indexable by position.
		k: Maximum number of results to return. ``None`` returns every
			document; zero or a negative value returns no results.

	Returns:
		A list of ``(document, similarity)`` tuples ordered by descending
		cosine similarity. The list is empty when there are no documents
		or when ``k`` is not positive.
	"""
	# Guard before doing any work: an empty corpus would make
	# cosine_similarity raise, and a negative k would make the slice below
	# drop results from the wrong end instead of returning none.
	if k is not None and k <= 0:
		return []
	if len(documents) == 0 or len(doc_embeddings) == 0:
		return []

	# Encode the query as a single-row batch
	query_embedding = model.encode([query])

	# Compute similarities; [0] selects the row belonging to the one query
	similarities = cosine_similarity(query_embedding, doc_embeddings)[0]

	# Rank documents based on similarity: argsort is ascending, so reverse
	# it and keep the first k indices
	ranked_doc_indices = np.argsort(similarities)[::-1][:k]

	# Output ranked documents with their similarity scores
	return [(documents[idx], similarities[idx]) for idx in ranked_doc_indices]


# Example query
if __name__ == "__main__":
	# Demo run: look up the stored documents closest to a sample
	# lane-change prompt and print each one with its similarity score.
	# (An earlier sample query was "domestic animals".)
	query = (
		"Now you are in the left lane and will decide whether to change lane. "
		"A red Mazda CX-5 is front of you. "
		"If choose the left lane in Arcane parlance, your TTC is 5.0 seconds in Arcane parlance;  "
		"If choose the right lane, your TTC is 3.7 seconds. "
		"Please reply your action by choosing one lane from [Left Lane, Right Lane] with largest TTC."
	)
	results = retrieve_documents(query, doc_embeddings, documents)

	for hit, similarity in results:
		print(f"Document: {hit}, Score: {similarity:.4f}")