from utils import all_entities, all_binary_preds, clean_cap

import spacy
import json
import os

# spaCy pipeline shared by every function below; loaded once when the module is imported.
nlp = spacy.load('en_core_web_lg')

# Parse each candidate label a single time. These Docs are reused for every
# caption when get_negative_examples scores them with Doc.similarity.
all_entity_rep = list(map(nlp, all_entities))
all_binary_rep = list(map(nlp, all_binary_preds))

def get_neg_examples_helper(candidate_reps, kw_reps, threshold):
    """Partition candidates into those similar to some keyword and those similar to none.

    A candidate counts as "similar" when ``candidate.similarity(kw)`` is
    strictly greater than ``threshold`` for at least one keyword ``kw``.

    Args:
        candidate_reps: Iterable of objects exposing ``similarity(other)``
            (spaCy ``Doc`` objects in this script).
        kw_reps: Keyword representations to compare against. It is iterated
            once per candidate, so it must be re-iterable (e.g. a list).
        threshold: Score a candidate must exceed for at least one keyword to
            count as similar.

    Returns:
        ``(all_sim, all_neg_sim)``: two lists that partition ``candidate_reps``
        and keep its original order. ``all_neg_sim`` holds the candidates that
        are not similar to any keyword, i.e. the negative examples. With an
        empty ``kw_reps`` every candidate is a negative.
    """
    all_sim = []
    all_neg_sim = []

    for candidate in candidate_reps:
        # any() stops at the first keyword above the threshold, so the remaining
        # keywords are not scored once a match is found.
        if any(candidate.similarity(kw) > threshold for kw in kw_reps):
            all_sim.append(candidate)
        else:
            all_neg_sim.append(candidate)

    return all_sim, all_neg_sim

def get_negative_examples(spec, cates, bin_kw):
    """Select vocabulary entries unrelated to a caption's spec as negative examples.

    A vocabulary entry from the module-level ``all_entity_rep`` /
    ``all_binary_rep`` lists is a negative candidate when its spaCy similarity
    to every keyword of the spec is at most 0.6. Ground-truth labels of the
    video (``cates`` / ``bin_kw``) are then removed. This guarantees that an
    object or relation that actually occurs is never returned as a negative.

    Args:
        spec: Parsed GPT spec. Its ``'consts'`` entry holds the entity keywords
            and its ``'binary_kws'`` entry holds the relation keywords.
        cates: Ground-truth object category names for the video.
        bin_kw: Ground-truth relation keyword names for the video.

    Returns:
        ``(neg_entity, neg_binary)``: lists of label strings in vocabulary order.
    """
    entity_kw_reps = [nlp(kw) for kw in set(spec['consts'])]
    binary_kw_reps = [nlp(kw) for kw in set(spec['binary_kws'])]

    _, neg_entity_docs = get_neg_examples_helper(all_entity_rep, entity_kw_reps, threshold=0.6)
    _, neg_binary_docs = get_neg_examples_helper(all_binary_rep, binary_kw_reps, threshold=0.6)

    # Compare by text. The helper returns spaCy Docs, which never compare equal
    # to plain label strings, so intersecting Docs with `cates` always
    # yields an empty set and cannot enforce this invariant.
    gt_entities = set(cates)
    gt_binaries = set(bin_kw)
    neg_entity = [doc.text for doc in neg_entity_docs if doc.text not in gt_entities]
    neg_binary = [doc.text for doc in neg_binary_docs if doc.text not in gt_binaries]

    return neg_entity, neg_binary

if __name__ == "__main__":
    # Build negative (unrelated) entity/relation examples for every caption of
    # the open_pvsg train/val splits that has a cached GPT spec, and merge them
    # into the on-disk negative-example cache.
    dataset = "open_pvsg"

    data_file_name = 'pvsg.json'
    # For the VideoLLaMA2 variant use instead:
    #   cache_file_name = "gpt_specs_videollamav2_prog_str.json"
    #   negative_file_name = "videollamav2_neg_examples.json"
    cache_file_name = "gpt_specs_prog_str.json"
    negative_file_name = "neg_examples.json"

    # The first ".." strips this file's own name, so this resolves to four
    # directories above the script's directory, then data/<dataset>.
    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../../data/{dataset}"))
    data_nl_dir = os.path.join(data_dir, 'nl2spec')

    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"Dataset directory not found: {data_dir}")
    os.makedirs(data_nl_dir, exist_ok=True)

    cache_path = os.path.join(data_nl_dir, cache_file_name)
    data_path = os.path.join(data_dir, data_file_name)
    negative_path = os.path.join(data_dir, negative_file_name)

    with open(cache_path, 'r', encoding='utf-8') as f:
        gpt_cache = json.load(f)
    with open(data_path, 'r', encoding='utf-8') as f:
        anno = json.load(f)
    data_split_info = anno['split']
    data_lookup = {data_dict['video_id']: data_dict for data_dict in anno['data']}

    full_dataset = []
    for phase in ["train", "val"]:
        for dataset_name, data_split in data_split_info.items():
            for data_id in data_split[phase]:
                video = data_lookup[data_id]
                for caption in video['captions']:
                    clean_des = clean_cap(caption['description'])
                    # Only captions that already have a GPT spec can be processed.
                    if clean_des not in gpt_cache:
                        continue

                    all_cates = {obj['category'] for obj in video['objects']}
                    # NOTE(review): assumes index 2 of each relation entry is the
                    # relation keyword (predicate) label; confirm against pvsg.json.
                    binary_kws = {rel[2] for rel in video['relations']}

                    full_dataset.append({
                        'data_id': data_id,
                        'caption': caption,
                        'gpt_spec': gpt_cache[clean_des],
                        'dataset': dataset_name,
                        'objects': video['objects'],
                        'meta': video['meta'],
                        'relations': video['relations'],
                        'cates': all_cates,
                        'binary_kws': binary_kws,
                        'clean_des': clean_des,
                    })

    print(f"Collected {len(full_dataset)} captions with cached GPT specs")

    # Resume from the existing cache if a previous run produced one.
    if os.path.exists(negative_path):
        with open(negative_path, 'r', encoding='utf-8') as f:
            negative_examples = json.load(f)
    else:
        negative_examples = {}

    try:
        for datapoint in full_dataset:
            clean_des = datapoint['clean_des']
            if clean_des in negative_examples:
                continue

            neg_entity, neg_binary = get_negative_examples(
                datapoint['gpt_spec'], datapoint['cates'], datapoint['binary_kws'])
            # Assign the complete entry at once so the cache never holds a half-filled record.
            negative_examples[clean_des] = {'neg_entity': neg_entity, 'neg_binary': neg_binary}
    finally:
        # Persist progress even if the loop is interrupted, so a re-run can
        # skip the captions that are already done.
        with open(negative_path, 'w', encoding='utf-8') as f:
            json.dump(negative_examples, f)

    print(f"Wrote negative examples for {len(negative_examples)} captions to {negative_path}")