import time
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import json
import re
import os, torch
from toolbench.utils import standardize, standardize_category, change_name, process_retrieval_ducoment


class ToolRetriever:
    """Dense retriever that maps a natural-language query to tools/APIs.

    Documents are read from a tab-separated corpus file, embedded with a
    SentenceTransformer model, and searched by cosine similarity. Corpus
    embeddings are cached on disk next to the corpus file, keyed by the
    model's name.
    """

    def __init__(self, corpus_tsv_path="", model_path=""):
        """Load the corpus, the embedding model, and the corpus embeddings.

        Args:
            corpus_tsv_path: Path to the tab-separated retrieval corpus.
            model_path: Path/name passed to ``SentenceTransformer``; its last
                path component is used to name the embedding cache file.
        """
        self.corpus_tsv_path = corpus_tsv_path
        self.model_path = model_path
        # Strip trailing slashes so "models/retriever/" yields "retriever"
        # instead of "" (an empty name would let different models share one
        # embedding cache file).
        self.model_name = model_path.rstrip('/').split('/')[-1]
        self.corpus, self.corpus2tool = self.build_retrieval_corpus()
        self.embedder = self.build_retrieval_embedder()
        self.corpus_embeddings = self.build_corpus_embeddings()

    def build_retrieval_corpus(self):
        """Read the corpus TSV and return ``(documents, corpus2tool)``.

        ``documents`` is a list of document texts in the iteration order of the
        dict returned by ``process_retrieval_ducoment``; a document's index in
        this list is the ``corpus_id`` reported by the semantic search.
        ``corpus2tool`` maps each document text to a string that
        :meth:`retrieving` expects to be ``"category[SEP]tool_name[SEP]api_name"``.
        """
        print("Building corpus...")
        documents_df = pd.read_csv(self.corpus_tsv_path, sep='\t')
        corpus, corpus2tool = process_retrieval_ducoment(documents_df)
        # Only the document texts are kept; list position becomes the corpus id.
        return list(corpus.values()), corpus2tool

    def build_retrieval_embedder(self):
        """Instantiate the SentenceTransformer from ``self.model_path``."""
        print("Building embedder...")
        embedder = SentenceTransformer(self.model_path)
        return embedder

    def _embedding_cache_path(self):
        """Return the on-disk embedding cache path for this corpus/model pair.

        The corpus file's extension is replaced by ``_<model_name>_embeddings.pt``.
        Using ``os.path.splitext`` (rather than ``str.replace('.tsv', ...)``)
        guarantees the result never equals the corpus path itself for
        non-``.tsv`` files and leaves ``.tsv`` in directory names untouched.
        """
        root, _ext = os.path.splitext(self.corpus_tsv_path)
        return f"{root}_{self.model_name}_embeddings.pt"

    def build_corpus_embeddings(self):
        """Return embeddings of ``self.corpus``, using a disk cache when present.

        The cache is reused whenever the file exists; it is NOT invalidated when
        the corpus file changes, so delete it after editing the corpus.
        """
        print("Building corpus embeddings with embedder...")
        embedding_save_path = self._embedding_cache_path()
        if os.path.exists(embedding_save_path):
            print("Loading pre-computed corpus embeddings...")
            # Map onto the embedder's device so a cache written on GPU can be
            # loaded on a CPU-only machine. torch.load unpickles its input;
            # the cache is generated locally, so only point it at trusted files.
            device = getattr(self.embedder, "device", None)
            return torch.load(embedding_save_path, map_location=device)
        print("Computing corpus embeddings...")
        corpus_embeddings = self.embedder.encode(self.corpus, convert_to_tensor=True)
        torch.save(corpus_embeddings, embedding_save_path)
        return corpus_embeddings

    @staticmethod
    def _split_tool_entry(entry):
        """Split ``"category[SEP]tool_name[SEP]api_name"`` into a 3-tuple.

        Raises:
            ValueError: if ``entry`` does not have exactly three ``[SEP]`` fields.
        """
        parts = entry.split('[SEP]')
        if len(parts) != 3:
            raise ValueError(
                f"Malformed corpus2tool entry (expected 'category[SEP]tool[SEP]api'): {entry!r}"
            )
        return tuple(parts)

    def retrieving(self, query, top_k=5, excluded_tools=None):
        """Return candidate tools/APIs for ``query``.

        Args:
            query: Query text to embed; only the first query's hits
                (``hits[0]``) are used.
            top_k: Base search size; ``10 * top_k`` nearest documents are fetched.
            excluded_tools: Optional mapping of standardized category ->
                collection of standardized tool names to skip. ``None`` (the
                default) excludes nothing.

        Returns:
            A list of ``{"category", "tool_name", "api_name"}`` dicts (all
            standardized), one per non-excluded hit, in the order produced by
            ``util.semantic_search``. NOTE: the list is not truncated to
            ``top_k``; it may contain up to ``10 * top_k`` entries, so callers
            needing exactly ``top_k`` must slice/filter themselves.

        Raises:
            ValueError: if a hit's ``corpus2tool`` entry is not of the form
                ``"category[SEP]tool_name[SEP]api_name"``.
        """
        if excluded_tools is None:
            excluded_tools = {}
        print("Retrieving...")
        query_embedding = self.embedder.encode(query, convert_to_tensor=True)
        # Over-fetch 10x candidates; presumably so exclusions (and any filtering
        # done by callers) still leave enough results.
        hits = util.semantic_search(query_embedding, self.corpus_embeddings, top_k=10*top_k, score_function=util.cos_sim)
        retrieved_tools = []
        for hit in hits[0]:
            entry = self.corpus2tool[self.corpus[hit['corpus_id']]]
            category, tool_name, api_name = self._split_tool_entry(entry)
            category = standardize_category(category)
            tool_name = standardize(tool_name)
            api_name = change_name(standardize(api_name))
            # Exclusions are matched against the standardized names.
            if category in excluded_tools and tool_name in excluded_tools[category]:
                continue
            retrieved_tools.append({
                "category": category,
                "tool_name": tool_name,
                "api_name": api_name,
            })
        return retrieved_tools