import json, os
from pyDatalog import pyDatalog
from utils.logic.fol import *
from pyDatalog.pyDatalog import assert_fact


# Dataset name interpolated into the ABox/TBox file names read at import time.
dataset = 'onto'

def parse_triple(triple):
    """Split a fact string such as ``pred(a, b)`` or ``pred(a)`` into its parts.

    Args:
        triple: fact text with one or two comma-separated arguments inside
            parentheses.

    Returns:
        ``(predicate, arg1, arg2)`` with surrounding whitespace stripped. For a
        unary fact, ``arg2`` is the sentinel string ``'none'``.

    Raises:
        ValueError: if the fact contains more than one comma (more than two
            arguments). Previously such input crashed with an
            UnboundLocalError because no branch assigned the results.
    """
    comma_count = triple.count(',')
    if comma_count > 1:
        raise ValueError(f'expected at most two arguments in fact: {triple!r}')
    predicate = triple.split('(')[0].strip()
    inner = triple.split('(')[1].split(')')[0]
    if comma_count == 1:
        arg1, arg2 = inner.split(',')
        return predicate, arg1.strip(), arg2.strip()
    return predicate, inner.strip(), 'none'

# Import-time setup (runs as a side effect of importing this module): read the
# ABox facts and TBox axioms for `dataset` from the relative 'data/' directory,
# which is resolved against the current working directory.
# NOTE(review): os.path.join with a single argument is a no-op here, and the
# setup below duplicates Datalog._init_datalog.
with open(os.path.join(f'data/abox-{dataset}.json') , 'r', encoding= 'utf-8') as file:
    abox = json.load(file)

with open(os.path.join(f'data/t-box-{dataset}.json') , 'r', encoding= 'utf-8') as file:
    tbox = json.load(file)




# Clear pyDatalog's state, then collect every name to declare as a pyDatalog
# term. Despite its name, `relations` also holds the lower-cased constants
# (fact arguments), not only predicate names.
pyDatalog.clear()
relations = set()
for fact in abox:
    predicate, arg1, arg2 = parse_triple(fact)
    arg1 = arg1.strip().lower(); arg2 = arg2.strip().lower()
    relations.add(predicate)
    relations.add(arg1)
    # 'none' is parse_triple's sentinel for a unary fact (no second argument).
    if arg2 != 'none':
        relations.add(arg2)

# Each TBox axiom is split as '<prefix>, <body> => <head>' with '&&'-separated
# body atoms (format inferred from the splitting below; confirm against the
# data). Record the head predicate and every body predicate.
for axiom in tbox:
    rhs = axiom.split('=>')[1].strip()
    predicate = rhs.split('(')[0].strip()
    relations.add(predicate)
    lhs = axiom.split('=>')[0].strip()
    # Drop everything up to the first comma (the prefix).
    lhs = lhs.split(',', 1)[1].strip()
    lhs_conjs = lhs.split('&&')
    for lhs_conj in lhs_conjs:
        pred = lhs_conj.split('(')[0].strip()
        relations.add(pred)

# Declare A, B, C, D (presumably the rule variables referenced by axioms that
# Datalog.__call__ exec()s; confirm) plus every collected name, in a single
# create_terms call.
string = 'A, B, C, D'

for relation in relations:
    string += ', ' + relation
pyDatalog.create_terms(string)
# Assert every ABox fact with lower-cased constants; unary facts get one argument.
for fact in abox:
    predicate, arg1, arg2 = parse_triple(fact)
    arg1 = arg1.strip().lower(); arg2 = arg2.strip().lower()
    if arg2 == 'none':
        assert_fact(predicate, arg1)
    else:
        assert_fact(predicate, arg1, arg2) 


def _transitive_shape(preds, args_lhs, args_rhs):
    """Decide transitivity from extracted predicate names and argument pairs.

    Args:
        preds: set of every predicate name seen in body and head.
        args_lhs: ``(first, second)`` argument names of each body atom, in order.
        args_rhs: ``(first, second)`` argument names of each head atom, in order.

    Returns:
        True when a single predicate is used throughout, the body atoms chain
        (each atom's second argument is the next atom's first argument), and
        the head spans the chain (first argument of the first body atom, second
        argument of the last one). Otherwise False.
    """
    if len(preds) > 1 or not args_lhs or not args_rhs:
        return False
    for (_, link_out), (link_in, _) in zip(args_lhs, args_lhs[1:]):
        if link_out != link_in:
            return False
    return args_lhs[0][0] == args_rhs[0][0] and args_lhs[-1][1] == args_rhs[-1][1]


def is_transitive(axiom):
    """Return True if *axiom* has the shape ``p(X1,X2) && p(X2,X3) ... => p(X1,Xn)``.

    Accepts either a ``ForAll`` formula (body in ``.body.lhs``, head in
    ``.body.rhs``) or the TBox string form ``'<prefix>, <body> => <head>'`` with
    ``&&``-separated atoms, where everything up to the first comma is
    discarded. Any atom that does not have exactly two arguments makes the
    axiom non-transitive. A single body atom whose arguments match the head,
    e.g. ``p(X, Y) => p(X, Y)``, also counts as True.

    Returns:
        Always a bool. The previous implementation returned None when the final
        shape check failed and crashed on unary atoms.
    """
    if isinstance(axiom, str):
        clause = axiom.split(',', 1)[1].strip()
        lhs, rhs = clause.split('=>')
        preds = set()
        args_lhs, args_rhs = [], []
        for conjs, sink in ((lhs.split('&&'), args_lhs), (rhs.split('&&'), args_rhs)):
            for conj in conjs:
                args = conj.split('(')[1].split(')')[0].split(',')
                if len(args) != 2:
                    return False
                preds.add(conj.split('(')[0].strip())
                sink.append((args[0].strip(), args[1].strip()))
        return _transitive_shape(preds, args_lhs, args_rhs)

    if isinstance(axiom, ForAll):
        implication = axiom.body
        preds = set()
        args_lhs, args_rhs = [], []
        for conjs, sink in ((flatten_and(implication.lhs), args_lhs),
                            (flatten_and(implication.rhs), args_rhs)):
            for conj in conjs:
                if len(conj.args) != 2:
                    return False
                preds.add(conj.name)
                sink.append((conj.args[0].name, conj.args[1].name))
        return _transitive_shape(preds, args_lhs, args_rhs)

    return False

def is_recursive(axiom):
    """Tell whether any head predicate of *axiom* also appears in its body.

    Accepts a ``ForAll`` formula (body ``.body.lhs``, head ``.body.rhs``) or
    the TBox string form ``'<prefix>, <body> => <head>'`` whose atoms are
    separated by ``&&``; the text up to the first comma is discarded.
    Any other input yields False.
    """
    if isinstance(axiom, ForAll):
        implication = axiom.body
        head, body = implication.rhs, implication.lhs
        body_atoms = flatten_and(body)
        head_atoms = flatten_and(head)
        body_names = {atom.name for atom in body_atoms}
        return any(atom.name in body_names for atom in head_atoms)

    if isinstance(axiom, str):
        clause = axiom.split(',', 1)[1].strip()
        body_text, head_text = clause.split('=>')
        head_pieces = head_text.split('&&')
        body_names = {piece.split('(')[0].strip() for piece in body_text.split('&&')}
        return any(piece.split('(')[0].strip() in body_names for piece in head_pieces)

    return False


def find_continuation(axiom, axioms):
    """Collect the axioms whose head can satisfy some body atom of *axiom*.

    A candidate is a continuation when one of its head predicate names equals
    the predicate name of a conjunct in ``axiom.body.lhs``. Candidates whose
    string form equals ``str(axiom)`` are excluded, so the axiom is not its
    own continuation.

    Args:
        axiom: a ``ForAll`` formula (needs ``.body.lhs``).
        axioms: iterable of ``ForAll`` formulas and/or TBox strings of the form
            ``'<body> => <head>'``. Other entries are ignored.

    Returns:
        A set of the matching candidates.

    Note:
        String candidates used to be compared as whole text against predicate
        objects, so they could never match. They are now matched by predicate
        name, like ``ForAll`` candidates.
    """
    body_names = {conjunct.name for conjunct in flatten_and(axiom.body.lhs)}
    axiom_text = str(axiom)
    continuations = set()
    for candidate in axioms:
        if isinstance(candidate, ForAll):
            if str(candidate) == axiom_text:
                continue
            head_names = {head.name for head in flatten_and(candidate.body.rhs)}
        elif isinstance(candidate, str):
            if candidate == axiom_text:
                continue
            head_text = candidate.split('=>')[1]
            head_names = {atom.split('(')[0].strip() for atom in head_text.split('&&')}
        else:
            continue
        if body_names & head_names:
            continuations.add(candidate)
    return continuations

def is_left_recursive(axiom):
    """Return True for a recursive ``ForAll`` axiom whose recursive atom shares no variable with the head.

    The axiom must be recursive (see ``is_recursive``). Its first body atom that
    reuses a head predicate name is examined. The axiom is reported as left
    recursive when none of that atom's argument names occur among the head's
    argument names. ``Datalog.__call__`` skips axioms for which this returns
    True.

    Returns:
        Always a bool. Non-``ForAll`` input now yields False instead of an
        implicit None.
    """
    if not isinstance(axiom, ForAll) or not is_recursive(axiom):
        return False
    lhs_conjs = flatten_and(axiom.body.lhs)
    # Materialised because the head conjuncts are iterated twice below.
    rhs_conjs = list(flatten_and(axiom.body.rhs))
    rhs_preds = {conj.name for conj in rhs_conjs}
    rhs_args = {arg.name for conj in rhs_conjs for arg in conj.args}
    for lhs_conj in lhs_conjs:
        if lhs_conj.name in rhs_preds:
            # NOTE(review): only the first body atom that reuses a head
            # predicate is inspected; later ones are ignored. Confirm intent.
            return not any(arg.name in rhs_args for arg in lhs_conj.args)
    return False


        


class Datalog:
    """Query an ABox with pyDatalog under a set of first-order axioms.

    Construction loads ``<path>/abox-<dataset>.json``, clears pyDatalog's
    global state and asserts every ABox fact. Calling the instance executes the
    given axioms as pyDatalog rules and asks a single query.

    NOTE(review): pyDatalog state is global. Rules exec()'d by ``__call__`` are
    never retracted by this class, so they accumulate across calls.
    """

    def __init__(self, path='data', dataset='onto'):
        """Load the ABox for *dataset* from directory *path* and initialise pyDatalog.

        Args:
            path: directory containing ``abox-<dataset>.json`` and
                ``t-box-<dataset>.json``.
            dataset: dataset name used in both file names.
        """
        self.dataset = dataset
        # Remember the directory so the TBox is read from the same place as the ABox.
        self.path = path
        self.abox = self._load_abox(path)
        self._init_datalog()

    def _init_datalog(self):
        """Clear pyDatalog, declare every predicate/constant name as a term, assert the ABox.

        The TBox file ``<path>/t-box-<dataset>.json`` is read only to harvest
        predicate names from axiom heads and bodies. The axioms themselves are
        not asserted here.
        """
        pyDatalog.clear()
        relations = set()
        for fact in self.abox:
            predicate, arg1, arg2 = parse_triple(fact)
            # Lower-case constants the same way they are lower-cased when asserted below.
            arg1 = arg1.strip().lower(); arg2 = arg2.strip().lower()
            relations.add(predicate)
            relations.add(arg1)
            # 'none' is parse_triple's sentinel for a unary fact.
            if arg2 != 'none':
                relations.add(arg2)

        tbox_path = os.path.join(self.path, f't-box-{self.dataset}.json')
        with open(tbox_path, 'r', encoding='utf-8') as file:
            tbox = json.load(file)

        for axiom in tbox:
            sides = axiom.split('=>')
            # Head predicate (right of '=>').
            relations.add(sides[1].strip().split('(')[0].strip())
            # Body atoms; text up to the first comma is a prefix that is dropped.
            body = sides[0].strip().split(',', 1)[1].strip()
            for conj in body.split('&&'):
                relations.add(conj.split('(')[0].strip())

        term_names = ', '.join(['A', 'B', 'C', 'D', *relations])
        # Kept directly in this method: create_terms works on its caller's frame.
        pyDatalog.create_terms(term_names)
        for fact in self.abox:
            predicate, arg1, arg2 = parse_triple(fact)
            arg1 = arg1.strip().lower(); arg2 = arg2.strip().lower()
            if arg2 == 'none':
                assert_fact(predicate, arg1)
            else:
                assert_fact(predicate, arg1, arg2)

    def __call__(self, axioms, query):
        """Execute *axioms* as pyDatalog rules and report whether *query* has an answer.

        Args:
            axioms: iterable of formulas. ``Predicate`` entries are facts and
                are skipped. Left-recursive axioms (see ``is_left_recursive``)
                are skipped. Every other entry is rendered with
                ``_standarize_axiom`` and exec()'d.
            query: a ``Predicate`` to ask, or a ``Not``. A ``Not`` is always
                answered False because the database holds no negative facts.

        Returns:
            True when ``pyDatalog.ask`` returns something other than None.
            False when it returns None or raises.

        Raises:
            TypeError: if *query* is neither a ``Not`` nor a ``Predicate``.
        """
        # No negations in database
        if isinstance(query, Not):
            return False

        to_exec_axioms = set()
        for axiom in axioms:
            # A bare Predicate is a fact, not an axiom.
            if isinstance(axiom, Predicate):
                continue
            # Checked before rendering so skipped axioms are never standardized.
            if is_left_recursive(axiom):
                continue
            to_exec_axioms.add(self._standarize_axiom(axiom))

        for axiom in to_exec_axioms:
            # exec() with no explicit namespaces resolves names in this method's
            # locals and then the module globals. Every predicate/variable name
            # in the rule text must therefore already be bound there
            # (presumably by the module-level create_terms call; confirm).
            # Avoid adding locals above that could shadow such names.
            # SECURITY: this executes text derived from the axioms; only pass
            # trusted axioms.
            exec(axiom)

        standard_query = self._standarize_query(query)
        try:
            answer = pyDatalog.ask(standard_query)
        except Exception:
            # Best effort: any pyDatalog failure (including TimeoutException)
            # counts as "no answer". KeyboardInterrupt/SystemExit now propagate.
            return False
        return answer is not None

    def _standarize_query(self, query):
        """Render a ground ``Predicate`` as a pyDatalog query string, e.g. ``p('a', 'b')``.

        Every argument is treated as a constant: lower-cased and single-quoted.

        Raises:
            TypeError: if *query* is not a ``Predicate``. This used to surface
                as an UnboundLocalError.
        """
        if not isinstance(query, Predicate):
            raise TypeError(f'unsupported query type: {type(query).__name__}')
        rendered_args = ', '.join(self._quote_constant(arg.name) for arg in query.args)
        return f'{query.name}({rendered_args})'

    def _standarize_axiom(self, axiom):
        """Render a ``ForAll`` implication ``body => head`` as pyDatalog text ``head <= body``.

        The FOL right-hand side (``axiom.body.rhs``) becomes the pyDatalog head.
        The left-hand side (``axiom.body.lhs``) becomes the ``&``-joined body.
        Conjuncts that are not ``Predicate`` instances are dropped. Variables
        written ``?X`` render as ``X``. Other argument names render as
        lower-cased quoted constants; previously they raised IndexError.
        """
        head = self._render_conjunction(axiom.body.rhs)
        body = self._render_conjunction(axiom.body.lhs)
        return head + ' <= ' + body

    @classmethod
    def _render_conjunction(cls, formula):
        """Render the ``Predicate`` conjuncts of *formula* joined with ``' & '``."""
        return ' & '.join(
            cls._render_atom(conjunct)
            for conjunct in flatten_and(formula)
            if isinstance(conjunct, Predicate)
        )

    @classmethod
    def _render_atom(cls, atom):
        """Render one atom (object with ``.name`` and ``.args``) as ``name(term, ...)``."""
        return atom.name + '(' + ', '.join(cls._render_term(arg.name) for arg in atom.args) + ')'

    @classmethod
    def _render_term(cls, name):
        """Render a term name: ``?X`` becomes ``X`` (variable); anything else becomes a quoted constant."""
        if '?' in name:
            return name.split('?')[1]
        return cls._quote_constant(name)

    @staticmethod
    def _quote_constant(name):
        """Quote *name* as a lower-cased single-quoted literal, matching the lower-cased ABox facts."""
        return "'" + name.lower() + "'"

    def _load_abox(self, file_path='data'):
        """Load and return the JSON list stored in ``<file_path>/abox-<dataset>.json``."""
        with open(os.path.join(file_path, f'abox-{self.dataset}.json'), 'r', encoding='utf-8') as file:
            abox = json.load(file)
        return abox
    


class TimeoutException(Exception):
    """Error raised by ``timeout_handler`` (with the message ``"Timed out!"``)."""

def timeout_handler(signum, frame):
    """Abort the current computation by raising ``TimeoutException``.

    The ``(signum, frame)`` signature matches a ``signal.signal`` handler.
    Presumably it is installed for an alarm signal by a caller outside this
    view; confirm. Both arguments are ignored.
    """
    message = "Timed out!"
    raise TimeoutException(message)