from utils.embeddings_utils import *
from dotenv import load_dotenv
import pickle, json
from openai import OpenAI
from sentence_transformers import SentenceTransformer, util


def rules2nl(rules):
    """Render every rule as text using ``str``.

    Returns a pair ``(rules_nl, nl2rules)``:
      * ``rules_nl`` -- the rendered strings, one per input rule, in input order.
      * ``nl2rules`` -- a dict mapping each rendered string back to its rule.
        When several rules render to the same string, the dict keeps the last
        such rule, while ``rules_nl`` still lists every occurrence.
    """
    # Render each rule exactly once, keeping it paired with its text.
    pairs = [(str(rule), rule) for rule in rules]
    text_to_rule = dict(pairs)
    texts = [text for text, _ in pairs]
    return texts, text_to_rule

def encode(s, model):
    if model == 'gpt3':
        load_dotenv()
        client = OpenAI()
        embeddings = client.embeddings.create(input=s, model= "text-embedding-3-small").data

    elif model == 'st':
        mpnet = SentenceTransformer(
            "sentence-transformers/all-mpnet-base-v2",
            cache_folder="models/sentence_transformers",
        )
        embeddings = mpnet.encode(s, normalize_embeddings=True, convert_to_tensor=True)
    return embeddings


def cache_tbox_embeddings_gpt3(dataset):
    """Embed a dataset's typed TBox rules with the OpenAI ('gpt3') backend.

    Loads ``data/{dataset}_typed_tbox.pkl`` (relative to the working
    directory), renders the rules with ``rules2nl`` and embeds the rendered
    strings with ``encode(..., 'gpt3')``.

    Args:
        dataset: ``"geo"`` or ``"onto"``.

    Returns:
        ``(t_box_embeddings, nl2rules, rules_nl)``.

    Raises:
        ValueError: If ``dataset`` is not ``"geo"`` or ``"onto"``.
    """
    # Previously an unknown dataset raised UnboundLocalError on t_box.
    if dataset not in ("geo", "onto"):
        raise ValueError(f"unknown dataset {dataset!r}; expected 'geo' or 'onto'")

    # Local project data file, not external input.
    with open(f"data/{dataset}_typed_tbox.pkl", "rb") as f:
        t_box = pickle.load(f)

    rules_nl, nl2rules = rules2nl(t_box)
    t_box_embeddings = encode(rules_nl, 'gpt3')

    return t_box_embeddings, nl2rules, rules_nl

def cache_tbox_embeddings_st(dataset):
    """Embed a dataset's typed TBox rules with the sentence-transformer ('st') backend.

    Loads ``data/{dataset}_typed_tbox.pkl`` (relative to the working
    directory), renders the rules with ``rules2nl`` and embeds the rendered
    strings with ``encode(..., 'st')``.

    Args:
        dataset: ``"geo"`` or ``"onto"``.

    Returns:
        ``(t_box_embeddings, nl2rules, rules_nl)``.

    Raises:
        ValueError: If ``dataset`` is not ``"geo"`` or ``"onto"``.
    """
    # Previously an unknown dataset raised UnboundLocalError on t_box.
    if dataset not in ("geo", "onto"):
        raise ValueError(f"unknown dataset {dataset!r}; expected 'geo' or 'onto'")

    # Local project data file, not external input.
    with open(f"data/{dataset}_typed_tbox.pkl", "rb") as f:
        t_box = pickle.load(f)

    rules_nl, nl2rules = rules2nl(t_box)
    t_box_embeddings = encode(rules_nl, 'st')

    return t_box_embeddings, nl2rules, rules_nl
    

def cache_abox_embeddings_gpt3(dataset):
    """Embed a dataset's typed ABox facts with the OpenAI ('gpt3') backend.

    Loads ``data/{dataset}_typed_abox.pkl`` (relative to the working
    directory), renders the entries with ``rules2nl`` and embeds the rendered
    strings with ``encode(..., 'gpt3')``.

    Args:
        dataset: ``"geo"`` or ``"onto"``.

    Returns:
        ``(a_box_embeddings, nl2rules, rules_nl)``.

    Raises:
        ValueError: If ``dataset`` is not ``"geo"`` or ``"onto"``.
    """
    # Previously an unknown dataset raised UnboundLocalError on a_box.
    if dataset not in ("geo", "onto"):
        raise ValueError(f"unknown dataset {dataset!r}; expected 'geo' or 'onto'")

    # Local project data file, not external input.
    with open(f"data/{dataset}_typed_abox.pkl", "rb") as f:
        a_box = pickle.load(f)

    rules_nl, nl2rules = rules2nl(a_box)
    a_box_embeddings = encode(rules_nl, 'gpt3')

    return a_box_embeddings, nl2rules, rules_nl
    
def cache_abox_embeddings_st(dataset):
    """Embed a dataset's typed ABox facts with the sentence-transformer ('st') backend.

    Loads ``data/{dataset}_typed_abox.pkl`` (relative to the working
    directory), renders the entries with ``rules2nl`` and embeds the rendered
    strings with ``encode(..., 'st')``.

    Args:
        dataset: ``"geo"`` or ``"onto"``.

    Returns:
        ``(a_box_embeddings, nl2rules, rules_nl)``.

    Raises:
        ValueError: If ``dataset`` is not ``"geo"`` or ``"onto"``.
    """
    # Previously an unknown dataset raised UnboundLocalError on a_box.
    if dataset not in ("geo", "onto"):
        raise ValueError(f"unknown dataset {dataset!r}; expected 'geo' or 'onto'")

    # Local project data file, not external input.
    with open(f"data/{dataset}_typed_abox.pkl", "rb") as f:
        a_box = pickle.load(f)

    rules_nl, nl2rules = rules2nl(a_box)
    a_box_embeddings = encode(rules_nl, 'st')

    return a_box_embeddings, nl2rules, rules_nl



if __name__ == "__main__":
    # Dataset to process: "geo" or "onto".
    dataset = "onto"
    # The 'gpt3' embeddings call the OpenAI API (network + credentials), so
    # they are off by default; set to True to also cache them.
    use_gpt3 = False

    t_box_embeddings_st, tbox_nl2rules, tbox_rules_nl = cache_tbox_embeddings_st(dataset)
    a_box_embeddings_st, abox_nl2rules, abox_rules_nl = cache_abox_embeddings_st(dataset)

    # Output file suffix -> object; each is pickled to data/{dataset}_{suffix}.pkl.
    outputs = {
        "tbox_embeddings_st": t_box_embeddings_st,
        "abox_embeddings_st": a_box_embeddings_st,
        "tbox_nl2rules": tbox_nl2rules,
        "abox_nl2rules": abox_nl2rules,
        "tbox_rules_nl": tbox_rules_nl,
        "abox_rules_nl": abox_rules_nl,
    }

    if use_gpt3:
        # The gpt3 helpers read the same input files and use the same
        # rules2nl, so their nl2rules/rules_nl match the 'st' ones; keep only
        # the embeddings.
        outputs["tbox_embeddings_gpt3"] = cache_tbox_embeddings_gpt3(dataset)[0]
        outputs["abox_embeddings_gpt3"] = cache_abox_embeddings_gpt3(dataset)[0]

    for suffix, obj in outputs.items():
        with open(f"data/{dataset}_{suffix}.pkl", "wb") as f:
            pickle.dump(obj, f)
