from lark import Lark, Transformer
import sys, re

### FOL classes ### 

class Constant:
    """A named constant term with an optional type tag."""

    def __init__(self, name, type=None):
        # Both attributes are stored as given; no validation is performed.
        self.name, self.type = name, type

    def __repr__(self) -> str:
        # Only the name is rendered; the type tag is not part of the text form.
        return self.name


class Variable:
    """A named variable term with an optional type tag."""

    def __init__(self, name, type=None):
        # Both attributes are stored as given; no validation is performed.
        self.name, self.type = name, type

    def __repr__(self) -> str:
        # Only the name is rendered; the type tag is not part of the text form.
        return self.name


class Predicate:
    """A predicate application: a name plus a tuple of argument terms."""

    def __init__(self, name, *args):
        self.name = name
        # Positional arguments are collected into a tuple.
        self.args = args

    def __repr__(self) -> str:
        # Rendered as ``name(arg1, arg2, ...)``.
        rendered_args = ", ".join(str(term) for term in self.args)
        return "{}({})".format(self.name, rendered_args)


class Function:
    """A function application: a name plus a tuple of argument terms."""

    def __init__(self, name, *args):
        self.name = name
        # Positional arguments are collected into a tuple.
        self.args = args

    def __repr__(self) -> str:
        # Rendered as ``name(arg1, arg2, ...)``.
        rendered_args = ", ".join(str(term) for term in self.args)
        return "{}({})".format(self.name, rendered_args)


class Not:
    """Negation of a single sentence."""

    def __init__(self, arg):
        self.arg = arg

    def __repr__(self) -> str:
        # Rendered with a leading ``~`` and no parentheses.
        return "~" + str(self.arg)


class And:
    """Binary conjunction of two sentences."""

    def __init__(self, arg1, arg2):
        self.arg1, self.arg2 = arg1, arg2

    def __repr__(self) -> str:
        # Rendered infix with ``&&``; no parentheses are added.
        return " && ".join((str(self.arg1), str(self.arg2)))


class Or:
    """Binary disjunction of two sentences."""

    def __init__(self, arg1, arg2):
        self.arg1, self.arg2 = arg1, arg2

    def __repr__(self) -> str:
        # Rendered infix with ``||``; no parentheses are added.
        return " || ".join((str(self.arg1), str(self.arg2)))


class Implies:
    """Implication with a left-hand side (premise) and right-hand side (conclusion)."""

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs

    def __repr__(self) -> str:
        # Rendered infix with ``=>``; no parentheses are added.
        return "{} => {}".format(str(self.lhs), str(self.rhs))


class ForAll:
    """Universally quantified sentence: a bound variable and a body."""

    def __init__(self, var, body):
        self.var, self.body = var, body

    def __repr__(self) -> str:
        # Rendered as ``FOR_ALL <var>, <body>``.
        return "FOR_ALL {}, {}".format(self.var, str(self.body))


class Exists:
    """Existentially quantified sentence: a bound variable and a body."""

    def __init__(self, var, body):
        self.var, self.body = var, body

    def __repr__(self) -> str:
        # Rendered as ``EXISTS <var>, <body>``.
        return "EXISTS {}, {}".format(self.var, str(self.body))

### FOL parser ### 

class FOLTransformer(Transformer):
    """Lark transformer that turns a parse tree into the FOL objects of this module.

    Each method is named after a grammar terminal (upper case) or rule
    (lower case) and builds the corresponding object. The grammar itself
    lives in a separate ``.lark`` file, so the exact child layout of each
    rule is assumed from how the children are indexed below.
    """

    def CONSTANT(self, c):
        """Build a Constant from a ``name`` or ``name:type`` token."""
        try:
            name, type_name = c.split(':')
        except (ValueError, AttributeError):
            # Not exactly one ':' separator (or not string-like): keep the
            # raw token as the name and leave the constant untyped.
            return Constant(c)
        return Constant(name.strip(), type_name.strip())

    def VARIABLE(self, v):
        """Build a Variable from a ``name`` or ``name:type`` token.

        A leading ``?`` is kept as part of the variable name.
        """
        try:
            name, type_name = v.split(':')
        except (ValueError, AttributeError):
            # Not exactly one ':' separator (or not string-like): keep the
            # raw token as the name and leave the variable untyped.
            return Variable(v)
        return Variable(name.strip(), type_name.strip())

    def function(self, children):
        """Build a Function; the first child is the name token, the rest are arguments."""
        name = children[0].value
        args = children[1:]

        return Function(name, *args)

    def predicate(self, children):
        """Build a Predicate; the first child is the name token, the rest are arguments."""
        name = children[0].value
        args = children[1:]

        return Predicate(name, *args)

    def negation(self, children):
        """Build a Not around the last child (the negated sentence)."""
        return Not(children[-1])

    def conjunction(self, children):
        """Build an And from the first and last children."""
        return And(children[0], children[-1])

    def disjunction(self, children):
        """Build an Or from the first and last children."""
        return Or(children[0], children[-1])

    def implication(self, children):
        """Build an Implies from the first (lhs) and last (rhs) children."""
        return Implies(children[0], children[-1])

    def quantifier(self, children):
        """Build a ForAll or Exists from ``[keyword token, variable, body]``.

        Any keyword other than ``FOR_ALL`` is treated as existential.
        """
        keyword = children[0].value
        var = children[1]
        body = children[2]

        if keyword == "FOR_ALL":
            return ForAll(var, body)
        return Exists(var, body)

# Build the grammar parser once, at import time, with "sentence" as the start rule.
# NOTE(review): the grammar path is relative to the current working directory,
# so importing this module from another directory fails with FileNotFoundError.
with open("utils/logic/fol.lark", "r") as f:
    parser = Lark(f, start="sentence")

# Single shared transformer instance, used by parse_fol below.
transformer = FOLTransformer()


def parse_fol(s):
    """Parse the sentence string ``s``.

    Returns a pair ``(tree, fol)``: the raw parse tree produced by the
    module-level ``parser`` and the object obtained by running the
    module-level ``transformer`` over that tree.
    """
    parse_tree = parser.parse(s)
    return parse_tree, transformer.transform(parse_tree)

### FOL functions ### 

def and_list(conjuncts):
    """Fold an iterable of sentences into a left-nested chain of And nodes.

    ``[a, b, c]`` becomes ``And(And(a, b), c)``. A single sentence is
    returned as is, and an empty iterable yields ``None``.
    """
    result = None

    for conjunct in conjuncts:
        # Compare against the None sentinel explicitly: a truthiness test
        # would silently drop an accumulated operand that happens to be falsy.
        result = conjunct if result is None else And(result, conjunct)

    return result

def or_list(disjuncts):
    """Fold an iterable of sentences into a left-nested chain of Or nodes.

    ``[a, b, c]`` becomes ``Or(Or(a, b), c)``. A single sentence is
    returned as is, and an empty iterable yields ``None``.
    """
    result = None

    for disjunct in disjuncts:
        # Compare against the None sentinel explicitly: a truthiness test
        # would silently drop an accumulated operand that happens to be falsy.
        result = disjunct if result is None else Or(result, disjunct)

    return result

def flatten_and(sentence):
    """Return the operands of a (possibly nested) conjunction as a flat list.

    Operands appear left to right. A sentence that is not an And is
    returned as a one-element list.
    """
    operands = []
    pending = [sentence]

    while pending:
        node = pending.pop()
        if isinstance(node, And):
            # Push the right side first so the left side is expanded first.
            pending.append(node.arg2)
            pending.append(node.arg1)
        else:
            operands.append(node)

    return operands

def flatten_or(sentence):
    """Return the operands of a (possibly nested) disjunction as a flat list.

    Operands appear left to right. A sentence that is not an Or is
    returned as a one-element list.
    """
    operands = []
    pending = [sentence]

    while pending:
        node = pending.pop()
        if isinstance(node, Or):
            # Push the right side first so the left side is expanded first.
            pending.append(node.arg2)
            pending.append(node.arg1)
        else:
            operands.append(node)

    return operands


def parse_predicate_string(s):
    """Build a Predicate from a string such as ``"meat_pasta(x) | x:food"``.

    The part before ``|`` supplies the predicate name; the part after it is
    a comma-separated list of ``variable:type`` pairs that become the
    predicate's Variable arguments, in the order listed.

    Raises:
        ValueError: if ``s`` does not contain exactly one ``|``, if the
            predicate part does not look like ``name(...)``, or if a
            variable entry is not a single ``variable:type`` pair.
    """
    predicate_part, variables_part = s.split('|')
    pred_match = re.match(r'([\w\s\-]+)\s*\(([^)]*)\)', predicate_part.strip())
    if not pred_match:
        raise ValueError("Invalid predicate format")

    # The name group also matches whitespace, so "name (x)" would capture
    # "name " -- trim it so the name compares equal to the parsed form.
    pred = pred_match.group(1).strip()

    args = []
    for pair in variables_part.split(','):
        var, var_type = pair.split(':')
        args.append(Variable(var.strip(), var_type.strip()))

    # Spread the variables as individual arguments, as every other Predicate
    # construction in this module does; passing the list itself would make
    # ``args`` a one-element tuple holding a list, which substitute() and
    # unify() cannot traverse.
    return Predicate(pred, *args)


def substitute(theta, clause):
    """Apply the substitution ``theta`` to ``clause`` and return the result.

    ``theta`` maps variable names to ``(name, type)`` pairs. A bound
    variable is replaced by a Variable when the replacement name contains
    ``?`` and by a Constant otherwise. Predicates, functions, negations,
    implications, conjunctions, disjunctions and universally quantified
    sentences are rebuilt with the substitution applied to their parts;
    anything else (including Exists) is returned unchanged.

    Conjunctions and disjunctions are flattened and rebuilt left-nested.
    The quantifier of a ForAll is kept while its body is substituted,
    including occurrences of the quantified variable itself.

    Returns ``None`` when ``theta`` is ``None`` (a failed unification).
    """
    if theta is None:
        return None

    if isinstance(clause, Variable):
        if clause.name not in theta:
            return clause
        binding = theta[clause.name]
        new_name, new_type = binding[0], binding[1]
        # A '?' in the replacement name marks it as still being a variable.
        if '?' in new_name:
            return Variable(new_name, new_type)
        return Constant(new_name, new_type)

    if isinstance(clause, Predicate):
        return Predicate(clause.name, *[substitute(theta, arg) for arg in clause.args])

    if isinstance(clause, Function):
        return Function(clause.name, *[substitute(theta, arg) for arg in clause.args])

    if isinstance(clause, Not):
        return Not(substitute(theta, clause.arg))

    if isinstance(clause, Implies):
        return Implies(substitute(theta, clause.lhs), substitute(theta, clause.rhs))

    if isinstance(clause, ForAll):
        # Recurse into whatever the body is, rather than assuming it is an
        # Implies; an Implies body is rebuilt as an Implies by the branch above.
        return ForAll(clause.var, substitute(theta, clause.body))

    if isinstance(clause, And):
        return and_list([substitute(theta, conjunct) for conjunct in flatten_and(clause)])

    if isinstance(clause, Or):
        return or_list([substitute(theta, disjunct) for disjunct in flatten_or(clause)])

    return clause
        
    

# Marks "no substitution supplied". None cannot be used for this because
# unify() already uses None to mean that unification has failed.
_NO_THETA = object()


def unify(x, y, theta=_NO_THETA):
    """Return a substitution that makes ``x`` and ``y`` identical, or None.

    x, y:  a Variable, Constant, Predicate, string, or tuple (a predicate's
           argument tuple, or a stored ``(name, type)`` binding).
    theta: the substitution built up so far, mapping variable names to
           ``(name, type)`` pairs. When omitted, a fresh empty dict is
           used for each call. ``None`` means an earlier step failed and
           is propagated.

    A supplied ``theta`` is extended in place and returned on success;
    ``None`` is returned when the two arguments cannot be unified.
    """
    if theta is _NO_THETA:
        # A fresh dict per call: bindings are written into theta in place,
        # so a shared default would leak bindings between unrelated calls.
        theta = {}

    if theta is None:
        return None
    elif str(x) == str(y):
        # Identical textual form: nothing to bind.
        return theta
    elif isinstance(x, Constant) and isinstance(y, str):
        return theta if x.name == y else None
    elif isinstance(y, Constant) and isinstance(x, str):
        return theta if y.name == x else None
    elif isinstance(x, Constant) and isinstance(y, tuple):
        # y is a stored (name, type) binding; only the name is compared.
        return theta if x.name == y[0] else None
    elif isinstance(y, Constant) and isinstance(x, tuple):
        return theta if y.name == x[0] else None
    elif isinstance(x, Variable):
        return unify_var(x, y, theta)
    elif isinstance(y, Variable):
        return unify_var(y, x, theta)
    elif isinstance(x, Predicate) and isinstance(y, Predicate):
        return unify(x.args, y.args, unify(x.name, y.name, theta))
    elif isinstance(x, tuple) and isinstance(y, tuple):
        if len(x) != len(y):
            # Different arity can never unify; without this check the
            # element-wise walk would index past the shorter tuple or
            # ignore the extra elements of the longer one.
            return None
        if len(x) == 1:
            return unify(x[0], y[0], theta)
        return unify(x[1:], y[1:], unify(x[0], y[0], theta))
    else:
        return None
        
def unify_var(var, x, theta):
    """Return a substitution that makes the variable ``var`` and ``x`` identical.

    var:   an object with a ``name`` attribute (a Variable).
    x:     a term with ``name`` and ``type`` attributes, or a
           ``(name, type)`` tuple taken from an existing binding.
    theta: the substitution so far; it is extended in place.

    If ``var`` is already bound, its binding is unified with ``x``. If
    ``x`` is a term whose name is already bound, ``var`` is unified with
    that binding. Otherwise ``var`` is bound to ``x``.
    """
    if var.name in theta:
        return unify(theta[var.name], x, theta)

    if isinstance(x, tuple):
        # x is already a stored (name, type) binding -- reached when a
        # binding is followed through unify(). It has no .name/.type
        # attributes, so record it directly.
        theta[var.name] = x
        return theta

    if x.name in theta:
        return unify(var, theta[x.name], theta)

    theta[var.name] = (x.name, x.type)
    return theta
    
def unify_clauses(clause1, clause2):
    """Return a substitution that makes ``clause1`` and ``clause2`` identical.

    Only a pair of Predicates is supported. Returns the substitution dict
    on success, and ``None`` when unification fails or when either
    argument is not a Predicate.
    """
    if isinstance(clause1, Predicate) and isinstance(clause2, Predicate):
        # Start from a fresh, empty substitution so bindings from an
        # earlier call can never leak into this one.
        return unify(clause1, clause2, {})

    return None
    



# Manual smoke test / demo, run only when the module is executed as a script.
if __name__ == "__main__":
    # Sample inputs; exactly one assignment to `s` is left active.
    #s = "FOR_ALL x, meat_pasta(x) && watching_weight(x) && ~spicy(x) => likes(U, x)"
    #s = "FOR_ALL ?x:food food, meat_pasta(?x:food food) && watching_weight(?x:food food) => likes(Khunyang Chhish:human human, x:food food)"
    s = "FOR_ALL ?A:location ?B:location ?C:location, PartOf(?A:location, ?B:location) && PartOf(?B:location, ?C:location) => PartOf(?A:location, ?C:location)"
    #s = "ContainedIn(?A:location, ?A:location)"
    #s = "meat_pasta(x) | x:food"
    tree, fol = parse_fol(s)
    
    # Check that the left-hand side of the rule body was parsed as a conjunction.
    print(isinstance(fol.body.lhs, And))
    # NOTE(review): this exit makes everything below unreachable; it looks
    # like a debugging leftover -- confirm before removing.
    sys.exit()


    # Inspect the raw parse tree and the types of the transformed objects.
    print(tree.pretty())
    print(type(fol.body.rhs.args[1]))
    print()

    print(type(fol))
    print(f"{type(fol.var)} -> {fol.var}")
    print(f"{type(fol.body)} -> {fol.body}")
    print(f"{type(fol.body.lhs)} -> {fol.body.lhs}")
    print(f"{type(fol.body.lhs.arg1)} -> {fol.body.lhs.arg1}")
    # you can keep checking the types and their contents...  
    
    # alternatively, define a list of conjuncts 
    print("\nalternative:")
    conjuncts = [
        Predicate("meat_pasta", Variable("x")),
        Predicate("watching_weight", Variable("x")),
        Not(Predicate("spicy", Constant("r1")))
    ]
    # Round trip: fold the list into nested And nodes, then flatten it again.
    ands = and_list(conjuncts)
    flatenned_list = flatten_and(ands)
    print(flatenned_list[0].name)
    print(flatenned_list[0].args[0])
    # The last conjunct is the Not; look through it at the wrapped predicate's argument.
    print(type(flatenned_list[-1].arg.args[0]))
    print(isinstance(flatenned_list[-1].arg.args[0], Constant))
    print((Not(Predicate("spicy", Variable("x")))).arg.name)
    print(Predicate("spicy", Variable("x")).args)
    print(Predicate("spicy", Constant("r1")).args)