from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any

from z3 import (
    BoolSort,
    Bool,
    Const,
    Function,
    ForAll,
    Exists,
    Implies,
    And,
    Or,
    Xor,
    Context,
    Not as zNot,
    DeclareSort,
)

from evaluators.adapters.fol.fol_lark import (
    parse,
    Term,
    Formula,
    Var,
    Const as AConst,
    Func,
    Pred,
    Eq,
    Not,
    And as AAnd,
    Or as AOr,
    Xor as AXor,
    Implies as AImplies,
    Iff,
    Forall,
    Exists as AExists,
)


@dataclass
class Signature:
    """Registry of the non-logical symbols used by one or more translated formulas.

    Every symbol ranges over a single uninterpreted sort ``U``. Each name is
    bound to one z3 symbol, which is cached and reused on later lookups. A name
    used inconsistently raises ``ValueError`` instead of silently producing a
    second, unrelated z3 symbol with the same name. Inconsistent use means a
    predicate/function clash within the n-ary tables, or a different arity
    (including 0-ary vs. n-ary).
    """

    # z3 context for the sort and Boolean symbols; None lets z3 use its default.
    ctx: Context | None = None
    # The uninterpreted domain sort; created in __post_init__.
    U: Any = field(init=False)
    # name -> (FuncDecl U^n -> Bool, n) for predicates with arity >= 1.
    preds: dict[str, tuple[Any, int]] = field(default_factory=dict)
    # name -> (FuncDecl U^n -> U, n) for functions with arity >= 1.
    funcs: dict[str, tuple[Any, int]] = field(default_factory=dict)
    # name -> constant of sort U (constants and 0-ary functions).
    consts: dict[str, Any] = field(default_factory=dict)
    # name -> Bool constant (0-ary predicates).
    bool_consts: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        self.U = DeclareSort("U", ctx=self.ctx)

    def get_func(self, name: str, arity: int):
        """Return the z3 symbol for function ``name`` applied to ``arity`` args.

        Arity 0 yields a constant of sort ``U`` (the same object ``get_const``
        returns); higher arities yield a ``FuncDecl`` from ``U^arity`` to ``U``.

        Raises:
            ValueError: if ``name`` is already an n-ary predicate, or was
                previously used as a function/constant with a different arity.
        """
        if name in self.preds:
            raise ValueError(f"Symbol '{name}' used as predicate and function")
        if name in self.funcs:
            decl, a = self.funcs[name]
            if a != arity:
                raise ValueError(f"Arity mismatch: '{name}': had:{a} now:{arity}")
            return decl
        if arity == 0:
            return self.get_const(name)
        # A known constant is a 0-ary function symbol; declaring an n-ary
        # function with the same name would create an unrelated z3 symbol.
        if name in self.consts:
            raise ValueError(f"Arity mismatch: '{name}': had:0 now:{arity}")
        decl = Function(name, *([self.U] * arity), self.U)
        self.funcs[name] = (decl, arity)
        return decl

    def get_pred(self, name: str, arity: int):
        """Return the z3 symbol for predicate ``name`` applied to ``arity`` args.

        Arity 0 yields a Bool constant (the same object ``get_bool_const``
        returns); higher arities yield a ``FuncDecl`` from ``U^arity`` to Bool.

        Raises:
            ValueError: if ``name`` is already an n-ary function, or was
                previously used as a predicate with a different arity.
        """
        if name in self.funcs:
            raise ValueError(f"Symbol '{name}' used as predicate and function")
        # Check the n-ary table before the 0-ary shortcut so that `P(x)`
        # followed by a bare `P` is reported instead of silently diverging.
        if name in self.preds:
            decl, a = self.preds[name]
            if a != arity:
                raise ValueError(f"Arity mismatch: '{name}': had:{a} now:{arity}")
            return decl
        if arity == 0:
            return self.get_bool_const(name)
        if name in self.bool_consts:
            raise ValueError(f"Arity mismatch: '{name}': had:0 now:{arity}")
        decl = Function(name, *([self.U] * arity), BoolSort(ctx=self.ctx))
        self.preds[name] = (decl, arity)
        return decl

    def get_const(self, name: str):
        """Return the cached constant ``name`` of sort ``U``, creating it on first use."""
        c = self.consts.get(name)
        if c is None:
            c = Const(name, self.U)
            self.consts[name] = c
        return c

    def get_bool_const(self, name: str):
        """Return the cached Bool constant ``name``, creating it on first use."""
        b = self.bool_consts.get(name)
        if b is None:
            b = Bool(name, ctx=self.ctx)
            self.bool_consts[name] = b
        return b


class Z3Translator:
    """Translate formulas from the ``fol_lark`` AST into z3 Boolean expressions.

    Non-logical symbols are resolved through ``self.sig``; translating several
    formulas with the same :class:`Signature` makes them reuse the same z3
    declarations.
    """

    def __init__(self, sig: Signature | None = None):
        # Only build a new signature when the caller did not provide one.
        self.sig = sig if sig is not None else Signature()

    def to_z3(self, f: Formula):
        """Translate formula ``f`` into a z3 Boolean expression.

        Translation starts with no bound variables, so every ``Var`` in ``f``
        must be bound by an enclosing quantifier.

        Raises:
            ValueError: for an unbound variable, a variable listed twice in one
                quantifier, or a symbol conflict reported by the signature.
            TypeError: for an AST node type this translator does not handle.
        """
        return self._formula(f, env_stack=[])

    def _formula(self, f: Formula, env_stack: list[dict[str, Any]]):
        """Recursively translate formula node ``f``.

        ``env_stack`` holds one frame per enclosing quantifier, mapping variable
        names to their z3 constants; later (inner) frames shadow earlier ones.
        """
        if isinstance(f, Pred):
            args = [self._term(t, env_stack) for t in f.args]
            pred = self.sig.get_pred(f.name, len(args))
            if not args:
                # 0-ary predicate: get_pred already returned a Bool constant.
                return pred
            # get_pred returns either a Bool constant or a FuncDecl; this check
            # narrows the type before applying it. TODO: fix the type inference.
            if not callable(pred):
                raise TypeError("Predicate not callable")
            return pred(*args)
        elif isinstance(f, Eq):
            left = self._term(f.left, env_stack)
            right = self._term(f.right, env_stack)
            return (left != right) if f.negated else (left == right)
        elif isinstance(f, Not):
            return zNot(self._formula(f.inner, env_stack), ctx=self.sig.ctx)
        elif isinstance(f, AAnd):
            return And(
                self._formula(f.left, env_stack), self._formula(f.right, env_stack)
            )
        elif isinstance(f, AOr):
            return Or(
                self._formula(f.left, env_stack), self._formula(f.right, env_stack)
            )
        elif isinstance(f, AXor):
            # NOTE: In the FOLIO dataset a chain of 3 or more XORs is meant as
            # "exactly one is true", not z3's "an odd number are true". Such
            # formulas are filtered out when building the FOL dataset, so nested
            # XORs are translated literally as chained binary (parity) XOR.
            return Xor(
                self._formula(f.left, env_stack),
                self._formula(f.right, env_stack),
                ctx=self.sig.ctx,
            )
        elif isinstance(f, AImplies):
            return Implies(
                self._formula(f.left, env_stack),
                self._formula(f.right, env_stack),
                ctx=self.sig.ctx,
            )
        elif isinstance(f, Iff):
            # Boolean equality is logical equivalence.
            return self._formula(f.left, env_stack) == self._formula(f.right, env_stack)
        elif isinstance(f, Forall):
            zvars, body = self._quantified(f.vars, f.body, env_stack)
            return ForAll(zvars, body)
        elif isinstance(f, AExists):
            zvars, body = self._quantified(f.vars, f.body, env_stack)
            return Exists(zvars, body)
        else:
            raise TypeError(f)

    def _quantified(self, names, body: Formula, env_stack: list[dict[str, Any]]):
        """Bind ``names`` in a new frame and translate ``body`` under it.

        Returns ``(zvars, zbody)``: the bound z3 constants in the order given and
        the translated body. The frame is popped even if translation raises.
        """
        frame = self._new_frame(names)
        env_stack.append(frame)
        try:
            zbody = self._formula(body, env_stack)
        finally:
            env_stack.pop()
        return list(frame.values()), zbody

    def _new_frame(self, names) -> dict[str, Any]:
        """Map each variable name to a new z3 constant of sort ``U``.

        Raises:
            ValueError: if a name appears more than once in one quantifier.
        """
        names = list(names)
        seen: set[str] = set()
        for v in names:
            if v in seen:
                raise ValueError(f"Duplicate variable '{v}'")
            seen.add(v)
        return {v: Const(v, self.sig.U) for v in names}

    def _term(self, t: Term, env_stack: list[dict[str, Any]]):
        """Translate term ``t`` into a z3 expression of sort ``U``.

        Raises:
            ValueError: if a variable is not bound by any enclosing quantifier.
            TypeError: for an unsupported term node.
        """
        if isinstance(t, Var):
            z = self._lookup_var(t.name, env_stack)
            if z is None:
                raise ValueError(f"Unbound: '{t.name}'")
            return z
        elif isinstance(t, AConst):
            return self.sig.get_const(t.name)
        elif isinstance(t, Func):
            arg_terms = [self._term(a, env_stack) for a in t.args]
            fun = self.sig.get_func(t.name, len(arg_terms))
            if not arg_terms:
                # 0-ary function: get_func already returned a constant.
                return fun
            # get_func returns either a constant or a FuncDecl; this check
            # narrows the type before applying it. TODO: fix the type inference.
            if not callable(fun):
                raise TypeError("Function not callable")
            return fun(*arg_terms)
        else:
            raise TypeError(t)

    def _lookup_var(self, name: str, env_stack: list[dict[str, Any]]):
        """Return the innermost binding of ``name`` in ``env_stack``, or None."""
        for frame in reversed(env_stack):
            if name in frame:
                return frame[name]
        return None


def parse_to_z3(formula_text: str, sig: Signature | None = None):
    """Parse ``formula_text`` with the ``fol_lark`` parser and translate it to z3.

    Pass the same ``sig`` to several calls so the resulting formulas share one
    signature (e.g. premises and a conclusion checked together). When ``sig`` is
    omitted, the translator creates a new :class:`Signature` for this call.
    """
    lark_ast = parse(formula_text)
    return Z3Translator(sig).to_z3(lark_ast)
