from utils import all_entities, all_binary_preds, clean_cap, all_gpt_binary_kws

import spacy
import json
import os
import pickle
import numpy as np

from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity
# Module-level OpenAI client; get_embedding() uses it to request embeddings.
client = OpenAI()

def get_embedding(texts, model="text-embedding-3-small", save_path=None):
    """Return embeddings for ``texts``, querying the API only for uncached ones.

    Newlines in every text are replaced by spaces before any lookup. When
    ``save_path`` points at an existing pickle file it is loaded as a
    ``{text: embedding}`` cache, and only texts missing from that cache are
    sent to ``client.embeddings.create``. If anything was queried and
    ``save_path`` is given, the merged mapping is written back to it.

    Args:
        texts: Iterable of strings to embed.
        model: Embedding model name forwarded to the API call.
        save_path: Optional path of the pickle cache to read and update.

    Returns:
        A dict mapping text to embedding entry. It contains every cached
        entry (including ones not in ``texts``) plus the newly queried ones.
        New values are the items of the API response's ``.data`` list.
    """
    texts = [text.replace("\n", " ") for text in texts]

    existing_embeddings = dict()
    if save_path is not None and os.path.exists(save_path):
        # NOTE: unpickling can execute arbitrary code; only point save_path
        # at cache files produced by this script.
        with open(save_path, 'rb') as cache_file:
            existing_embeddings = pickle.load(cache_file)

    # dict.fromkeys de-duplicates while keeping first-seen order, so a text
    # that appears several times is only sent to the API once.
    to_query_text = list(dict.fromkeys(
        t for t in texts if t not in existing_embeddings
    ))

    if not to_query_text:
        return existing_embeddings

    print(f"quering: {to_query_text}")
    embeddings = client.embeddings.create(input=to_query_text, model=model).data
    result = dict(zip(to_query_text, embeddings))
    result.update(existing_embeddings)

    if save_path is not None:
        # Context manager guarantees the cache is flushed and closed.
        with open(save_path, 'wb') as cache_file:
            pickle.dump(result, cache_file)
    return result


def get_neg_examples_helper(candidate_reps, kw_reps, topk=15):
    """Split candidates into those most similar to the keywords and the rest.

    Each candidate is scored by its maximum cosine similarity to any vector
    in ``kw_reps``. The cutoff is the score found at 0-based rank ``topk``
    of the descending ordering, and every candidate scoring at or above it
    counts as similar. The cutoff is inclusive, so up to ``topk + 1``
    candidates (more on ties) land in the similar list; this matches the
    existing selection behavior.

    Args:
        candidate_reps: Mapping of candidate name to embedding vector.
        kw_reps: Mapping of keyword to embedding vector.
        topk: 0-based rank of the cutoff score. When it is not smaller than
            the number of candidates, the lowest score becomes the cutoff
            and every candidate counts as similar.

    Returns:
        ``(all_sim, all_neg_sim)``: two lists of ``(candidate_name, score)``
        tuples in the iteration order of ``candidate_reps``. With no
        keywords, ``all_sim`` is empty and every candidate is returned in
        ``all_neg_sim`` with a score of ``-1``.
    """
    if len(kw_reps) == 0:
        return [], [(k, -1) for k in candidate_reps]
    if len(candidate_reps) == 0:
        # Nothing to rank; avoids passing an empty matrix to cosine_similarity.
        return [], []

    candidate_names = list(candidate_reps)
    pairwise_sim = cosine_similarity(
        list(candidate_reps.values()), list(kw_reps.values())
    )
    # Best match over all keywords, one score per candidate.
    best_scores = np.max(pairwise_sim, axis=1)

    descending = (-best_scores).argsort()
    # Clamp so a topk beyond the candidate count no longer raises IndexError.
    cutoff_rank = min(topk, len(descending) - 1)
    threshold = best_scores[descending[cutoff_rank]]

    all_sim = []
    all_neg_sim = []
    for name, score in zip(candidate_names, best_scores):
        bucket = all_sim if score >= threshold else all_neg_sim
        bucket.append((name, score))

    return all_sim, all_neg_sim

def get_negative_examples(bin_kw, gpt_binary_kws, gt_binary_embs, gpt_binary_embs, topk=15):
    """Return the share of ``bin_kw`` recovered among the top-ranked candidates.

    The embeddings of ``gpt_binary_kws`` are looked up in ``gpt_binary_embs``
    (underscores replaced by spaces in the lookup key). The candidates in
    ``gt_binary_embs`` are then ranked against them with
    ``get_neg_examples_helper``, and the result is the fraction of ``bin_kw``
    that appears in the helper's "similar" list.

    Args:
        bin_kw: Set of reference keywords to look for.
        gpt_binary_kws: Iterable of keyword strings to rank against.
        gt_binary_embs: Mapping of candidate name to embedding vector.
        gpt_binary_embs: Mapping of keyword text to an object exposing an
            ``.embedding`` attribute.
        topk: Cutoff rank forwarded to ``get_neg_examples_helper``.

    Returns:
        ``len(bin_kw & similar_names) / len(bin_kw)`` as a float, or ``0.0``
        when ``bin_kw`` is empty (the ratio is undefined there and
        previously raised ``ZeroDivisionError``).

    Raises:
        KeyError: If a keyword has no entry in ``gpt_binary_embs``.
    """
    if not bin_kw:
        return 0.0

    current_binary_reps = {
        kw: gpt_binary_embs[kw.replace('_', ' ')].embedding
        for kw in gpt_binary_kws
    }
    all_binary_sim, _ = get_neg_examples_helper(
        gt_binary_embs, current_binary_reps, topk=topk
    )
    # Names come from gt_binary_embs, i.e. they are candidate names.
    top_candidate_names = {name for name, _ in all_binary_sim}

    return len(bin_kw.intersection(top_candidate_names)) / len(bin_kw)

if __name__ == "__main__":
    # Script entry point: for every validation caption that has a cached GPT
    # spec, measure how many of the video's annotated relation keywords fall
    # among the candidates ranked closest to the GPT keywords, then report
    # the mean of those overlap rates.
    print('here')
    dataset = "open_pvsg"

    data_file_name = 'pvsg.json'
    # Alternative file names, kept for manual switching:
    # negative_file_name = "videollamav2_neg_examples.json"
    # cache_file_name = f"gpt_specs_videollamav2_prog_str.json"

    cache_file_name = "gpt_specs_prog_str.json"
    negative_file_name = "neg_examples.json"

    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../data/{dataset}"))
    data_nl_dir = os.path.join(data_dir, 'nl2spec')
    gpt_bin_save_path = os.path.join(data_nl_dir, "gpt_binary_kws_embs.pkl")
    gt_bin_save_path = os.path.join(data_nl_dir, "gt_binary_kws_embs.pkl")

    # Validate and create directories BEFORE fetching embeddings:
    # get_embedding writes its cache into data_nl_dir, so the directory must
    # already exist on a first run.
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"dataset directory not found: {data_dir}")
    os.makedirs(data_nl_dir, exist_ok=True)

    gpt_binary_embs = get_embedding(all_gpt_binary_kws, model="text-embedding-3-large", save_path=gpt_bin_save_path)
    gt_binary_embs = get_embedding(all_binary_preds, model="text-embedding-3-large", save_path=gt_bin_save_path)
    # Keep only the raw vectors for the candidates.
    gt_binary_embs = {k: v.embedding for k, v in gt_binary_embs.items()}

    cache_path = os.path.join(data_nl_dir, cache_file_name)
    data_path = os.path.join(data_dir, data_file_name)
    # The three paths below are not read by the logic in this block.
    negative_path = os.path.join(data_dir, negative_file_name)
    video_save_dir = os.path.join(data_dir, 'pred_video')
    model_dir = os.path.join(data_dir, 'model')

    with open(cache_path, 'r') as cache_file:
        gpt_cache = json.load(cache_file)
    with open(data_path, 'r') as data_file:
        anno = json.load(data_file)

    data_split_info = anno['split']
    data_lookup = {data_dict['video_id']: data_dict for data_dict in anno['data']}
    all_overlap_rates = []
    new_text = set()
    topk = 60
    for data_split in data_split_info.values():
        for data_id in data_split['val']:
            video = data_lookup[data_id]

            for caption in video['captions']:
                clean_des = clean_cap(caption['description'])
                if clean_des not in gpt_cache:
                    continue

                # Third field of each relation entry is used as its keyword.
                binary_kws = {relation[2] for relation in video['relations']}

                gpt_spec = gpt_cache[clean_des]
                gpt_kws = gpt_spec['binary_kws'] + gpt_spec['unary_kws']
                for kw in gpt_kws:
                    new_text.add(kw.replace('_', ' '))

                if not binary_kws:
                    # The overlap rate is undefined without reference
                    # keywords; skip instead of dividing by zero.
                    continue

                overlap_rate = get_negative_examples(
                    binary_kws,
                    gpt_kws,
                    gt_binary_embs=gt_binary_embs,
                    gpt_binary_embs=gpt_binary_embs,
                    topk=topk,
                )
                all_overlap_rates.append(overlap_rate)

    # Uncomment to embed the keywords collected in new_text into the cache
    # stored at gpt_bin_save_path:
    # get_embedding(new_text, model="text-embedding-3-large", save_path=gpt_bin_save_path)
    if all_overlap_rates:
        mean_overlap_rates = float(np.mean(all_overlap_rates))
    else:
        # np.mean([]) would emit a RuntimeWarning; make the empty case explicit.
        mean_overlap_rates = float('nan')
    print(f"mean overlap rate over {len(all_overlap_rates)} captions (topk={topk}): {mean_overlap_rates}")
    print('here')