from __future__ import annotations
from dataclasses import dataclass
from lark import Lark, Tree, Token

# Lark grammar for first-order-logic formulas.
#
# Terminals: each connective accepts a Unicode symbol plus ASCII/keyword
# spellings.  FORALL/EXISTS carry terminal priority 5 and VAR priority 2;
# NAME has no explicit priority.  VAR must start with a lowercase letter,
# while NAME may start with any letter or digit.  Whitespace is ignored.
#
# Rules, from loosest to tightest binding:
#   equivalence (IFF)      - right-recursive
#   implication (IMPLIES)  - right-recursive
#   disjunction (OR, XOR)  - left-recursive, both at the same level
#   conjunction (AND)      - left-recursive
#   unary                  - negation, quantified formula, or primary
#   primary                - predicate application, (in)equality, or a
#                            parenthesised formula
#
# A quantifier prefix is followed by a single `unary`, so a wider body has to
# be parenthesised.  `var_list` accepts variables separated by commas and/or
# plain whitespace.  The aliases after `->` (iff, implies, or, xor, and, neg,
# parens, forall, exists, eq, pred, fun, sym) are the tree node names the AST
# builder functions dispatch on.
FOL_GRAMMAR = r"""
NOT: "¬" | "~" | "not"
AND: "∧" | "&" | "and"
OR:  "∨" | "|" | "or"
XOR: "⊕" | "xor" | "^"
IMPLIES: "→" | "->" | "=>"
IFF: "↔" | "⟷" | "<->" | "<=>" | "iff"

FORALL.5: "∀" | "forall"
EXISTS.5: "∃" | "exists"

EQ: "="
NEQ: "≠" | "!="

LPAR: "("
RPAR: ")"
COMMA: ","

VAR.2: /[a-z][A-Za-z0-9_']*/

NAME: /[A-Za-z0-9][A-Za-z0-9_-]*/

%import common.WS
%ignore WS

start: formula

?formula: equivalence

?equivalence: implication
            | implication IFF equivalence          -> iff

?implication: disjunction
            | disjunction IMPLIES implication      -> implies

?disjunction: conjunction
            | disjunction OR  conjunction          -> or
            | disjunction XOR conjunction          -> xor

?conjunction: unary
            | conjunction AND unary                -> and

?unary: NOT unary                                  -> neg
      | quantified
      | primary

?primary: pred_app
        | equality
        | LPAR formula RPAR                        -> parens

quantified: quant_prefix unary

quant_prefix: FORALL var_list                      -> forall
            | EXISTS var_list                      -> exists

var_list: VAR ( (COMMA VAR) | VAR )*

equality: term (EQ|NEQ) term                        -> eq

pred_app: NAME LPAR [ terms ] RPAR                  -> pred

?term: (NAME|VAR) LPAR [ terms ] RPAR               -> fun
    | (NAME|VAR)                                    -> sym

terms: term (COMMA term)*
"""


# Term and Formula are marker base classes: they exist only to indicate the
# AST structure and intentionally contain nothing.
class Term:
    """Empty marker base class shared by the term nodes of the AST."""


class Formula:
    """Empty marker base class shared by the formula nodes of the AST."""


@dataclass(frozen=True)
class Var(Term):
    """Term node for a symbol that is bound by an enclosing quantifier."""

    # Variable name exactly as written in the formula text.
    name: str


@dataclass(frozen=True)
class Const(Term):
    """Term node for a bare symbol that no enclosing quantifier binds."""

    # Constant name exactly as written in the formula text.
    name: str


@dataclass(frozen=True)
class Func(Term):
    """Term node for a function application such as ``f(a, b)``."""

    # Function symbol.
    name: str
    # Argument terms in source order; empty for ``f()``.
    # NOTE(review): because this is a list, hash() on an instance raises
    # TypeError even though the dataclass is frozen.
    args: list[Term]


@dataclass(frozen=True)
class Pred(Formula):
    """Atomic formula: a predicate applied to zero or more terms."""

    # Predicate symbol.
    name: str
    # Argument terms in source order; empty for ``P()``.
    # NOTE(review): because this is a list, hash() on an instance raises
    # TypeError even though the dataclass is frozen.
    args: list[Term]


@dataclass(frozen=True)
class Eq(Formula):
    """Atomic formula comparing two terms for equality or inequality."""

    # Term on the left-hand side of the operator.
    left: Term
    # Term on the right-hand side of the operator.
    right: Term
    # True when the operator was an inequality (NEQ token), False for ``=``.
    negated: bool = False


@dataclass(frozen=True)
class Not(Formula):
    """Negation of a formula."""

    # The formula being negated.
    inner: Formula


@dataclass(frozen=True)
class And(Formula):
    """Conjunction of two formulas."""

    # Left operand.
    left: Formula
    # Right operand.
    right: Formula


@dataclass(frozen=True)
class Or(Formula):
    """Disjunction of two formulas."""

    # Left operand.
    left: Formula
    # Right operand.
    right: Formula


@dataclass(frozen=True)
class Xor(Formula):
    """Exclusive disjunction of two formulas."""

    # Left operand.
    left: Formula
    # Right operand.
    right: Formula


@dataclass(frozen=True)
class Implies(Formula):
    """Implication between two formulas."""

    # Antecedent (the formula written before the arrow).
    left: Formula
    # Consequent (the formula written after the arrow).
    right: Formula


@dataclass(frozen=True)
class Iff(Formula):
    """Biconditional (if and only if) between two formulas."""

    # Left operand.
    left: Formula
    # Right operand.
    right: Formula


@dataclass(frozen=True)
class Forall(Formula):
    """Universally quantified formula."""

    # Names of the variables bound by this quantifier, in source order.
    # NOTE(review): because this is a list, hash() on an instance raises
    # TypeError even though the dataclass is frozen.
    vars: list[str]
    # Formula in which the listed variables are bound.
    body: Formula


@dataclass(frozen=True)
class Exists(Formula):
    """Existentially quantified formula."""

    # Names of the variables bound by this quantifier, in source order.
    # NOTE(review): because this is a list, hash() on an instance raises
    # TypeError even though the dataclass is frozen.
    vars: list[str]
    # Formula in which the listed variables are bound.
    body: Formula


# Shared parser instance, built once at import time from FOL_GRAMMAR using
# Lark's LALR algorithm.  With maybe_placeholders=False an unmatched optional
# item such as `[ terms ]` contributes no child to the tree (instead of a None
# placeholder); the AST builders below look for a "terms" sub-tree accordingly.
_parser = Lark(FOL_GRAMMAR, start="start", parser="lalr", maybe_placeholders=False)


def parse(formula_text: str) -> Formula:
    """Parse ``formula_text`` and return it as a ``Formula`` AST.

    The text is run through the module-level Lark parser and the resulting
    parse tree is converted with ``_build_formula``, starting with no
    quantifier scopes in effect.
    """
    scopes: list[set[str]] = []
    return _build_formula(_parser.parse(formula_text), bound_stack=scopes)


def _extract_first_tree(children):
    """Return the first ``Tree`` among ``children``, or ``None`` if there is none."""
    return next((child for child in children if isinstance(child, Tree)), None)


def _build_formula(node: Tree | Token, bound_stack: list[set[str]]) -> Formula:
    """Convert a Lark parse (sub)tree into the corresponding ``Formula`` node.

    Args:
        node: Parse-tree node to convert.  It must be a ``Tree`` whose ``data``
            is one of the formula-level rule names or aliases of the grammar.
        bound_stack: Stack of variable-name sets, one per enclosing quantifier
            (innermost last).  A scope is pushed while a quantifier body is
            being built and is always popped again, even when building the
            body raises, so the caller's stack is left as it was found.

    Returns:
        The ``Formula`` built from ``node``.

    Raises:
        AssertionError: If ``node`` is a bare token, a wrapper node contains no
            sub-tree, or the node type is not handled.
    """
    if not isinstance(node, Tree):
        raise AssertionError("Token where formula expected")

    typ = node.data

    if typ == "start":
        return _build_formula(node.children[0], bound_stack)

    if typ == "parens":
        # Children are LPAR, the inner formula, RPAR; only the tree matters.
        inner = _extract_first_tree(node.children)
        if inner is None:
            raise AssertionError("Empty parentheses")
        return _build_formula(inner, bound_stack)

    if typ == "neg":
        # Children are the NOT token followed by the negated formula.
        inner = _extract_first_tree(node.children)
        if inner is None:
            raise AssertionError("Empty negation")
        return Not(_build_formula(inner, bound_stack))

    if typ in ("and", "or", "xor", "implies", "iff"):
        # Children are [left operand, operator token, right operand].
        left = _build_formula(node.children[0], bound_stack)
        right = _build_formula(node.children[2], bound_stack)
        cls = {"and": And, "or": Or, "xor": Xor, "implies": Implies, "iff": Iff}[typ]
        return cls(left, right)

    if typ == "eq":
        left = _build_term(node.children[0], bound_stack)
        op = node.children[1]
        if not isinstance(op, Token):
            raise AssertionError("Expected Token type")
        right = _build_term(node.children[2], bound_stack)
        return Eq(left, right, negated=(op.type == "NEQ"))

    if typ == "pred":
        name_tok = node.children[0]
        # The "terms" sub-tree is absent when the predicate has no arguments.
        args: list[Term] = []
        for ch in node.children:
            if isinstance(ch, Tree) and ch.data == "terms":
                # Skip the COMMA tokens interleaved with the term trees.
                args = [
                    _build_term(c, bound_stack)
                    for c in ch.children
                    if isinstance(c, Tree)
                ]
        return Pred(str(name_tok), args)

    if typ == "quantified":
        # Children are [quant_prefix tree ("forall"/"exists"), body].
        qprefix: Tree = node.children[0]
        var_list_tree = next(
            (
                ch
                for ch in qprefix.children
                if isinstance(ch, Tree) and ch.data == "var_list"
            ),
            None,
        )
        vars_ = (
            [
                str(tok)
                for tok in var_list_tree.scan_values(
                    lambda v: isinstance(v, Token) and v.type == "VAR"
                )
            ]
            if var_list_tree is not None
            else []
        )
        # The new scope is visible only while the body is built; try/finally
        # guarantees it is removed even if the body fails to build.
        bound_stack.append(set(vars_))
        try:
            body = _build_formula(node.children[1], bound_stack)
        finally:
            bound_stack.pop()
        return Forall(vars_, body) if qprefix.data == "forall" else Exists(vars_, body)

    if typ in (
        "equivalence",
        "implication",
        "disjunction",
        "conjunction",
        "unary",
        "primary",
        "formula",
    ):
        # Pass-through wrapper node: descend into its first sub-tree.
        inner = _extract_first_tree(node.children)
        if inner is None:
            raise AssertionError(f"Unexpected empty node {typ}")
        return _build_formula(inner, bound_stack)

    raise AssertionError(f"Unhandled node type: {typ}")


def _build_term(node: Tree | Token, bound_stack: list[set[str]]) -> Term:
    """Convert a term-level parse tree into a ``Term`` node.

    Args:
        node: Parse-tree node to convert; must be a ``Tree``.
        bound_stack: Stack of variable-name sets, one per enclosing
            quantifier.  A bare symbol found in any of them becomes a ``Var``;
            otherwise it becomes a ``Const``.

    Returns:
        ``Func`` for a "fun" node, ``Var`` or ``Const`` for a "sym" node, or
        the term built from the only sub-tree of a wrapper node.

    Raises:
        AssertionError: If ``node`` is a bare token, a "parens" node has no
            sub-tree, or the node cannot be interpreted as a term.
    """
    if not isinstance(node, Tree):
        raise AssertionError("Token at term level")

    kind = node.data

    if kind == "sym":
        symbol = str(node.children[0])
        # Search scopes from the innermost quantifier outwards.
        for scope in reversed(bound_stack):
            if symbol in scope:
                return Var(symbol)
        return Const(symbol)

    if kind == "fun":
        fn_name = str(node.children[0])
        fn_args: list[Term] = []
        term_groups = (
            child
            for child in node.children
            if isinstance(child, Tree) and child.data == "terms"
        )
        for group in term_groups:
            # Only the term trees count; COMMA tokens are skipped.
            fn_args = [
                _build_term(sub, bound_stack)
                for sub in group.children
                if isinstance(sub, Tree)
            ]
        return Func(fn_name, fn_args)

    if kind == "parens":
        wrapped = _extract_first_tree(node.children)
        if wrapped is None:
            raise AssertionError("Empty parentheses in term")
        return _build_term(wrapped, bound_stack)

    # Any other node is accepted only as a wrapper around exactly one sub-tree.
    subtrees = [child for child in node.children if isinstance(child, Tree)]
    if len(subtrees) != 1:
        raise AssertionError(f"Unhandled term node {kind}")
    return _build_term(subtrees[0], bound_stack)

if __name__ == "__main__":
    # Demo: parse one sample formula and print the resulting AST.
    sample_text = "∀x ((Student(x) ∧ ∃y (Take(x,y) ∧ InstructedBy(y, professorDavid))) → Take(x, databaseCourse))"
    print(parse(sample_text))