import json
import os
import re

# Regex for a parenthesised group made up solely of ASCII letters, digits,
# commas, hyphens and spaces, e.g. "(1)" or "(2, 3)".
# Written as a raw string so the backslashes reach the regex engine verbatim
# instead of being parsed as (invalid) string escape sequences, which newer
# Python versions warn about.
re_num = r"\([0-9a-zA-Z\,\- ]+\)"


def get_overlap(x, y):
    """Intersect two ``(start, end)`` pairs.

    Args:
        x: Indexable whose items 0 and 1 are the start and end of a range.
        y: Second range, same layout as ``x``.

    Returns:
        ``(start, end)`` of the shared part, where ``start`` is the larger of
        the two starts and ``end`` the smaller of the two ends, provided
        ``start < end``. Otherwise the sentinel pair ``(-1, -1)``; ranges
        that merely touch (``start == end``) count as non-overlapping.
    """
    lower = max(x[0], y[0])
    upper = min(x[1], y[1])
    return (lower, upper) if lower < upper else (-1, -1)
        
def comp_diff(gpt_output, origin_entities, original_rels):
    """Collect predicted and original keyword sets for side-by-side comparison.

    Despite the name, no difference is computed here; the four sets are
    simply gathered and handed back to the caller.

    Args:
        gpt_output: Mapping with the keys ``'binary_kws'`` and ``'consts'``,
            each holding an iterable of hashable items.
        origin_entities: Iterable of hashable items, returned as a set.
        original_rels: Mapping whose values are iterables of relation
            records; item 2 of every record is collected.

    Returns:
        A 4-tuple of sets, in this order: ``gpt_output['binary_kws']``,
        item 2 of every record in ``original_rels``, ``gpt_output['consts']``,
        and ``origin_entities``.
    """
    predicted_relations = set(gpt_output['binary_kws'])
    predicted_consts = set(gpt_output['consts'])
    original_relations = {
        record[2]
        for records in original_rels.values()
        for record in records
    }
    return predicted_relations, original_relations, predicted_consts, set(origin_entities)

def clean_cap(caption):
    """Strip parenthesised annotations from a caption and tidy its spacing.

    Every substring matched by the module-level ``re_num`` pattern is removed
    (all occurrences of each matched text), after which a single pass of
    clean-ups is applied: double spaces become single spaces, a space before
    a period or a comma is dropped, and surrounding whitespace is trimmed.

    The clean-up steps are kept exactly as they were; only unused local
    variables (including a discarded ``caption.split(' ')``) were removed.

    Args:
        caption: The caption text to clean.

    Returns:
        The cleaned caption string.
    """
    cleaned = caption
    for token in re.findall(re_num, caption):
        cleaned = cleaned.replace(token, '')

    # One pass each, in this order; runs of three or more spaces are
    # therefore only partially collapsed.
    cleaned = cleaned.replace('  ', ' ')
    cleaned = cleaned.replace(' .', '.')
    cleaned = cleaned.replace(' ,', ',')
    return cleaned.strip()

if __name__ == "__main__":
    # Script entry point: for every caption that has an entry in the GPT
    # report, tally how often each binary keyword appears in the prediction
    # and in the original relations overlapping the caption's time range,
    # then list the original keywords that were never predicted.

    gpt_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../../data/open_pvsg/nl2spec/gpt_specs_prog_str.json"))
    data_path = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../../data/open_pvsg/pvsg.json"))

    # NOTE(review): save_dir is not used anywhere in the visible code.
    save_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../../data/spec_inspection"))

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(gpt_path):
        raise FileNotFoundError(gpt_path)

    # Context managers make sure both file handles are closed.
    with open(gpt_path, 'r') as gpt_file:
        gpt_report = json.load(gpt_file)
    with open(data_path, 'r') as data_file:
        data = json.load(data_file)

    # keyword -> number of captions in which it occurs
    all_pred_binary_kws = {}
    all_gt_binary_kws = {}

    for datapoint in data['data']:
        print('here')
        for caption in datapoint['captions']:
            cleaned_caption = clean_cap(caption['description'])
            duration = caption['time'].split('-')
            if len(duration) != 2:
                continue

            time_start, time_end = duration

            # An empty bound (e.g. "-120" or "30-") stays at the -1 sentinel.
            start = int(time_start) if time_start else -1
            end = int(time_end) if time_end else -1

            # Skip captions without a GPT entry before building the per-frame
            # relation table, which would otherwise be thrown away.
            if cleaned_caption not in gpt_report:
                continue

            # Load relationships within caption range
            new_relations = {i: [] for i in range(start, end)}
            for sub_id, obj_id, rel, time_ls in datapoint['relations']:
                for from_t, to_t in time_ls:
                    lap_start, lap_end = get_overlap((from_t, to_t), (start, end))
                    if lap_start == -1:
                        # (-1, -1) is get_overlap's "no overlap" sentinel.
                        continue
                    for i in range(lap_start, lap_end):
                        new_relations[i].append((sub_id, obj_id, rel))

            all_cates = {o['category'] for o in datapoint['objects']}
            pred_binary_kws, origin_binary_kws, _, _ = comp_diff(gpt_report[cleaned_caption], all_cates, new_relations)

            # Count each set on its own. Zipping the two sets together paired
            # unrelated keywords and silently dropped every keyword beyond the
            # size of the smaller set.
            for pred_binary_kw in pred_binary_kws:
                all_pred_binary_kws[pred_binary_kw] = all_pred_binary_kws.get(pred_binary_kw, 0) + 1
            for origin_binary_kw in origin_binary_kws:
                all_gt_binary_kws[origin_binary_kw] = all_gt_binary_kws.get(origin_binary_kw, 0) + 1

    # Original keywords that never show up among the predicted ones.
    # NOTE(review): this list is not printed or saved in the visible code.
    no_match = [kw for kw in all_gt_binary_kws if kw not in all_pred_binary_kws]

    print("here")