import json
import pickle
from utils.logic.fol import *

def parse_triple(triple):
    """Split a textual atom like ``pred(a, b)`` or ``pred(a)`` into its parts.

    The arity is decided by the number of commas in the whole string: exactly
    one comma means a binary atom, no comma means a unary atom.

    Args:
        triple: String of the form ``predicate(arg1, arg2)`` or
            ``predicate(arg1)``.

    Returns:
        A ``(predicate, arg1, arg2)`` tuple of whitespace-stripped strings.
        For a unary atom ``arg2`` is the literal string ``'None'`` (not the
        ``None`` object), which callers compare against.

    Raises:
        ValueError: If the string contains more than one comma, or if a
            one-comma string does not yield exactly two arguments between
            the parentheses.
        IndexError: If the string contains no ``'('``.
    """
    comma_count = triple.count(',')
    if comma_count > 1:
        # Previously this case fell through both branches and crashed with
        # UnboundLocalError at the return statement; fail explicitly instead.
        raise ValueError(
            f"expected a unary or binary atom, got {comma_count + 1} "
            f"comma-separated parts: {triple!r}"
        )

    pieces = triple.split('(')
    predicate = pieces[0].strip()
    # Text after the first '(' up to the next ')' holds the argument list.
    inner = pieces[1].split(')')[0]

    if comma_count == 1:
        arg1, arg2 = inner.split(',')
        return predicate, arg1.strip(), arg2.strip()
    return predicate, inner.strip(), 'None'

# Input/output locations and triple-parsing mode for each supported dataset.
# "allow_unary" selects whether one-argument atoms are accepted (via
# parse_triple) or whether every triple must have exactly two arguments.
_DATASET_CONFIGS = {
    "geo": {
        "tbox_in": "data/t-box.json",
        "tbox_out": "data/geo_typed_tbox.pkl",
        "data_in": "data/geo_data.json",
        "typed_abox_out": "data/geo_typed_abox.pkl",
        "updated_data_out": "data/updated_geo_data.json",
        "abox_out": "data/abox-geo.json",
        "allow_unary": False,
    },
    "onto": {
        "tbox_in": "data/t-box-onto.json",
        "tbox_out": "data/onto_typed_tbox.pkl",
        "data_in": "data/onto_data.json",
        "typed_abox_out": "data/onto_typed_abox.pkl",
        "updated_data_out": "data/updated_onto_data.json",
        "abox_out": "data/abox-onto.json",
        "allow_unary": True,
    },
}


def _read_json(path):
    """Load a JSON document from ``path``, decoding it as UTF-8."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _write_json(obj, path):
    """Write ``obj`` to ``path`` as indented UTF-8 JSON without ASCII escaping."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=4, ensure_ascii=False)


def _write_pickle(obj, path):
    """Pickle ``obj`` to ``path``."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def _parse_tbox(t_box):
    """Return the second element of ``parse_fol(entry)`` for every T-box entry."""
    parsed_tbox = []
    for entry in t_box:
        _, fol = parse_fol(entry)
        parsed_tbox.append(fol)
    return parsed_tbox


def _typed_predicate(triple, types, allow_unary):
    """Build the typed ``Predicate`` for one triple string.

    Args:
        triple: Atom text such as ``pred(a, b)``.
        types: Mapping from argument name to its type, taken from the entry.
        allow_unary: When true the triple is parsed with ``parse_triple`` and
            a second argument equal to ``'None'`` produces a one-argument
            predicate; when false the triple must split into exactly two
            arguments.

    Raises:
        KeyError: If an argument has no entry in ``types``.
        ValueError, IndexError: If the triple text cannot be split as expected.
    """
    if allow_unary:
        predicate, arg1, arg2 = parse_triple(triple)
    else:
        predicate = triple.split('(')[0].strip()
        arg1, arg2 = triple.split('(')[1].split(')')[0].split(',')
    arg1 = arg1.strip()
    arg2 = arg2.strip()

    arg1_type = types[arg1]
    if allow_unary and arg2 == 'None':
        return Predicate(predicate, Constant(arg1, arg1_type))
    arg2_type = types[arg2]
    return Predicate(predicate, Constant(arg1, arg1_type), Constant(arg2, arg2_type))


def _build_abox(data, allow_unary):
    """Collect the distinct triples of ``data`` in raw and typed form.

    Args:
        data: Iterable of entries, each with a ``"triples"`` list and a
            ``"types"`` mapping.
        allow_unary: Passed through to ``_typed_predicate``.

    Returns:
        ``(typed_abox, abox, updated_data)``: the typed predicates, the
        matching raw triple strings (first-seen order, no duplicates), and
        the entries whose triples were all processed without error.
    """
    typed_abox = []
    abox = []
    updated_data = []
    # Mirrors ``abox`` for O(1) duplicate checks; the list keeps output order.
    seen = set()

    for entry in data:
        triples = entry["triples"]
        types = entry["types"]
        try:
            for triple in triples:
                if triple in seen:
                    continue
                typed_abox.append(_typed_predicate(triple, types, allow_unary))
                abox.append(triple)
                seen.add(triple)
        except Exception:
            # Best effort: an entry with an unparsable triple or a missing
            # type is left out of updated_data. Triples of that entry handled
            # before the failure stay in the A-box, as before. Unlike a bare
            # ``except``, this lets KeyboardInterrupt/SystemExit propagate.
            continue
        updated_data.append(entry)

    return typed_abox, abox, updated_data


def main(dataset):
    """Build and save the typed T-box and A-box files for ``dataset``.

    Args:
        dataset: ``"geo"`` or ``"onto"``; any other value does nothing.
    """
    config = _DATASET_CONFIGS.get(dataset)
    if config is None:
        return

    parsed_tbox = _parse_tbox(_read_json(config["tbox_in"]))
    _write_pickle(parsed_tbox, config["tbox_out"])

    # Load the dataset itself to form the A-box.
    data = _read_json(config["data_in"])
    typed_abox, abox, updated_data = _build_abox(data, config["allow_unary"])

    _write_pickle(typed_abox, config["typed_abox_out"])
    _write_json(updated_data, config["updated_data_out"])
    _write_json(abox, config["abox_out"])


if __name__ == "__main__":
    dataset = "onto"
    main(dataset)