from utils import all_entities, all_binary_preds, clean_cap

import spacy 
import json
import os

# Single spaCy pipeline, loaded once at import time; get_negative_examples
# uses it to parse the spec keywords.
nlp = spacy.load('en_core_web_lg')

# Parse the full candidate vocabularies once, up front, so each call to
# get_negative_examples only has to parse its (small) keyword lists.
all_entity_rep = list(map(nlp, all_entities))
all_binary_rep = list(map(nlp, all_binary_preds))

def get_neg_examples_helper(candidate_reps, kw_reps, threshold):
    """Partition candidates into keyword-similar and dissimilar groups.

    A candidate counts as "similar" when its ``similarity`` score against at
    least one keyword representation is strictly greater than ``threshold``;
    every other candidate is a negative example.

    Args:
        candidate_reps: iterable of objects exposing ``similarity(other)``
            (spaCy ``Doc`` objects in this module).
        kw_reps: iterable of keyword representations to compare against. It
            is materialized once, so a one-shot iterator is also accepted.
        threshold: score a candidate must exceed against some keyword to be
            considered similar.

    Returns:
        ``(all_sim, all_neg_sim)``: two lists that together partition
        ``candidate_reps``, each preserving the input order.
    """
    # Re-iterated once per candidate; a generator would otherwise be exhausted
    # after the first candidate and every later one would look dissimilar.
    kw_reps = list(kw_reps)

    all_sim = []
    all_neg_sim = []
    for candidate in candidate_reps:
        # any() stops at the first keyword above the threshold instead of
        # scoring all remaining keywords.
        if any(candidate.similarity(kw) > threshold for kw in kw_reps):
            all_sim.append(candidate)
        else:
            all_neg_sim.append(candidate)

    return all_sim, all_neg_sim

def get_negative_examples(spec, cates, bin_kw, threshold=0.6):
    """Return candidate entities / binary predicates dissimilar to a spec.

    Every keyword in ``spec['consts']`` and ``spec['binary_kws']`` is parsed
    with ``nlp`` and compared against the pre-parsed module-level candidates
    (``all_entity_rep`` / ``all_binary_rep``). Candidates whose similarity to
    every keyword is at most ``threshold`` are negatives. Ground-truth labels
    are never returned as negatives.

    Args:
        spec: mapping with iterables of keyword strings under ``'consts'``
            (entities) and ``'binary_kws'`` (binary predicates).
        cates: ground-truth entity category names to exclude from the
            entity negatives.
        bin_kw: ground-truth binary predicate names to exclude from the
            binary negatives.
        threshold: similarity a candidate must exceed against some keyword to
            count as matching it (default 0.6).

    Returns:
        ``(neg_entities, neg_binaries)``: lists of candidate texts, in the
        order of the module-level candidate lists.
    """
    # Deduplicate keywords so each one is parsed only once.
    entity_kw_reps = [nlp(kw) for kw in set(spec['consts'])]
    binary_kw_reps = [nlp(kw) for kw in set(spec['binary_kws'])]

    _, neg_entity_reps = get_neg_examples_helper(all_entity_rep, entity_kw_reps, threshold=threshold)
    _, neg_binary_reps = get_neg_examples_helper(all_binary_rep, binary_kw_reps, threshold=threshold)

    # Ground-truth labels must never be reported as negatives. Compare by text:
    # spaCy Doc objects do not compare equal to plain strings, so intersecting
    # the label sets with the Doc lists directly is always empty and can never
    # detect an overlap.
    gt_entities = set(cates)
    gt_binaries = set(bin_kw)
    neg_entities = [doc.text for doc in neg_entity_reps if doc.text not in gt_entities]
    neg_binaries = [doc.text for doc in neg_binary_reps if doc.text not in gt_binaries]

    return neg_entities, neg_binaries

if __name__ == "__main__":
    # Build negative entity / binary-predicate examples for every train and
    # val caption that has a cached GPT spec, and write them to
    # <data_dir>/neg_examples.json keyed by the cleaned caption text.
    print('here')
    dataset = "open_pvsg"

    cache_file_name = "gpt_specs_scl.json"
    data_file_name = 'pvsg.json'
    negative_file_name = "neg_examples.json"

    # "<this file>/../../../data/<dataset>" resolves to data/<dataset> two
    # directories above the directory containing this script.
    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../data/{dataset}"))
    data_nl_dir = os.path.join(data_dir, 'nl2spec')

    assert os.path.exists(data_dir)
    os.makedirs(data_nl_dir, exist_ok=True)

    cache_path = os.path.join(data_nl_dir, cache_file_name)
    data_path = os.path.join(data_dir, data_file_name)
    negative_path = os.path.join(data_dir, negative_file_name)

    # Sanity check that the script sits next to its `scl` directory.
    # NOTE(review): nothing below uses this directory; consider dropping it.
    scl_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../scl'))
    assert os.path.exists(scl_dir)

    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(cache_path, 'r', encoding='utf-8') as f:
        gpt_cache = json.load(f)
    with open(data_path, 'r', encoding='utf-8') as f:
        anno = json.load(f)

    data_split_info = anno['split']
    data_lookup = {data_dict['video_id']: data_dict for data_dict in anno['data']}
    full_dataset = []

    for phase in ["train", "val"]:
        for dataset_name, data_split in data_split_info.items():
            for data_id in data_split[phase]:
                video = data_lookup[data_id]

                # Ground-truth labels belong to the video, not the caption, so
                # compute them once per video.
                all_cates = {obj['category'] for obj in video['objects']}
                # presumably relation entries hold the predicate name at
                # index 2 — confirm against the pvsg.json schema
                binary_kws = {relation[2] for relation in video['relations']}

                for caption in video['captions']:
                    clean_des = clean_cap(caption['description'])
                    if clean_des not in gpt_cache:
                        continue

                    full_dataset.append({'data_id': data_id,
                                         'caption': caption,
                                         'gpt_spec': gpt_cache[clean_des],
                                         'dataset': dataset_name,
                                         'objects': video['objects'],
                                         'meta': video['meta'],
                                         'relations': video['relations'],
                                         'cates': all_cates,
                                         'binary_kws': binary_kws,
                                         'clean_des': clean_des})

    negative_examples = {}
    for datapoint in full_dataset:
        neg_entity, neg_binary = get_negative_examples(datapoint['gpt_spec'], datapoint['cates'], datapoint['binary_kws'])
        # If several datapoints share a cleaned caption, the last one wins.
        negative_examples[datapoint['clean_des']] = {
            'neg_entity': neg_entity,
            'neg_binary': neg_binary,
        }

    # Close (and flush) the output file deterministically.
    with open(negative_path, 'w', encoding='utf-8') as f:
        json.dump(negative_examples, f)
    print('here')
