from typing import Dict, List, Tuple

from pddl.core import Action, Constant, Formula, Predicate
from pddl.logic.base import BinaryOp, ExistsCondition, ForallCondition, UnaryOp, Variable
from pddl.logic.effects import When
from pddl.logic.predicates import EqualTo

from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.utils.pddl_utils import flatten_formula_to_predicates


def verify_predicate(
    g_predicate: Predicate, l_predicate: Predicate, objects: List[Constant], type_hierarchy: Dict[str, List[str]]
) -> None:
    """Verify that the grounded predicate `g_predicate` is an instance of the lifted predicate `l_predicate`.

    Both predicates must have the same arity. For each position, the lifted term must carry
    exactly one type tag, which is the expected type. The grounded term at that position is
    then checked twice:

    * if it carries a type tag, that type must be a key of `type_hierarchy`, and the expected
      type must be among its entries;
    * its name must match one of `objects`, and that object's declared `type_tag` must be a key
      of `type_hierarchy` whose entries include the expected type.

    Args:
        g_predicate: The grounded predicate to verify
        l_predicate: The lifted predicate definition
        objects: List of available objects for type checking
        type_hierarchy: Dictionary mapping type names to their parent types (including themselves)

    Raises:
        ValueError: On an arity mismatch, an unknown type, an unknown object, or a term whose
            type is not a subtype of the expected type.
    """
    if len(g_predicate.terms) != len(l_predicate.terms):
        raise ValueError(
            f"Predicate {g_predicate.name} has {len(g_predicate.terms)} terms, "
            f"but definition has {len(l_predicate.terms)} terms."
        )

    # Index the objects once instead of scanning the list for every term. `setdefault` keeps the
    # first object for a repeated name, i.e. the same object a first-match linear search finds.
    objects_by_name = {}
    for candidate in objects:
        objects_by_name.setdefault(candidate.name, candidate)

    for term, expected_term in zip(g_predicate.terms, l_predicate.terms):
        # Lifted definitions are required to be single-typed (no `either` types).
        assert expected_term.type_tags is not None and len(expected_term.type_tags) == 1
        exp_term_type = next(iter(expected_term.type_tags))
        term_type = next(iter(term.type_tags)) if term.type_tags else None

        if term_type is not None:
            if term_type not in type_hierarchy:
                raise ValueError(f"{term_type} is not a valid type.")

            if exp_term_type not in type_hierarchy[term_type]:
                raise ValueError(
                    f"Predicate `{g_predicate}` error: Object {term.name} type {term_type} is no subtype of expected type {exp_term_type}."
                )

        obj = objects_by_name.get(term.name)
        if obj is None:
            raise ValueError(f"Object {term.name} is not a valid object.")

        assert obj.type_tag is not None
        # Same validation as for the term's own type tag: an undeclared type is a validation
        # error, not a KeyError from the lookup below.
        if obj.type_tag not in type_hierarchy:
            raise ValueError(f"{obj.type_tag} is not a valid type.")
        if exp_term_type not in type_hierarchy[obj.type_tag]:
            raise ValueError(
                f"Predicate `{g_predicate}` error: Object {term.name} type {obj.type_tag} is no subtype of expected type {exp_term_type}."
            )


def verify_formula(formula: Formula, variables: List[Variable], domain: PDDLDomain):
    """Recursively verify every predicate atom in `formula` against the predicates of `domain`.

    Compound formulas are traversed as follows:

    * binary operators: each operand;
    * unary operators: the argument;
    * lists and tuples: each item;
    * `exists` / `forall`: the condition, with the quantified variables added to the scope;
    * `when`: both the condition and the effect.

    Equality atoms are accepted without checks.

    Each predicate atom is checked by replacing every in-scope variable with a placeholder
    constant (`t0`, `t1`, ...) that carries the variable's type tag. The resulting grounded atom
    is passed to `verify_predicate`, together with the domain's predicate definition of the same
    name and the domain's type hierarchy.

    Args:
        formula: The formula (or list/tuple of formulas) to verify.
        variables: Variables in scope, e.g. the action parameters.
        domain: The domain providing predicate definitions and the type hierarchy.

    Raises:
        ValueError: If an atom uses a term that is not an in-scope variable, names a predicate
            that is not defined in the domain, or fails `verify_predicate`.
        NotImplementedError: If `formula` (or a sub-formula) is of an unsupported type.
    """
    if isinstance(formula, BinaryOp):
        for operand in formula.operands:
            verify_formula(operand, variables, domain)
    elif isinstance(formula, Predicate):
        # Map each variable name to a typed placeholder constant. On a repeated name the dict keeps
        # the LAST entry, so inner (quantified) variables must come after outer ones in `variables`.
        # NOTE(review): `list(...)[0]` picks an arbitrary tag when a variable has several types and
        # raises IndexError for an untyped variable.
        terms = {term.name: Constant(f"t{i}", list(term.type_tags)[0]) for i, term in enumerate(variables)}

        known_predicates_by_name = {p.definition.name: p.definition for p in domain.predicates}
        if any(t.name not in terms for t in formula.terms):
            raise ValueError(f"Predicate `{formula.name}` has terms that are not defined in the variables.")
        grounded_pred = Predicate(formula.name, *[terms[t.name] for t in formula.terms])
        known_predicate = known_predicates_by_name.get(formula.name)
        if known_predicate is None:
            raise ValueError(f"Predicate `{formula.name}` is not defined in the domain.")

        verify_predicate(
            g_predicate=grounded_pred,
            l_predicate=known_predicate,
            objects=list(terms.values()),
            type_hierarchy=domain.get_type_hierarchy(),
        )
    elif isinstance(formula, UnaryOp):
        verify_formula(formula.argument, variables, domain)
    elif isinstance(formula, (list, tuple)):
        for item in formula:
            verify_formula(item, variables, domain)
    elif isinstance(formula, EqualTo):
        # Equality atoms have no predicate signature to check.
        return
    elif isinstance(formula, (ExistsCondition, ForallCondition)):
        # Quantified variables go AFTER the outer ones so that they shadow outer variables with the
        # same name (the name -> constant mapping above keeps the last entry).
        verify_formula(formula.condition, list(variables) + list(formula.variables), domain)
    elif isinstance(formula, When):
        verify_formula(formula.condition, variables, domain)
        verify_formula(formula.effect, variables, domain)
    else:
        raise NotImplementedError(f"Unsupported formula type: {type(formula)}")


def verify_action(action: Action, domain: PDDLDomain):
    """Verify an action's precondition and effect against `domain`.

    Both formulas are checked with `verify_formula`, using the action's parameters as the
    variables in scope. A missing precondition (`None`) imposes no constraints and is skipped.
    A missing effect, or an effect that flattens to no predicates, is rejected.

    Args:
        action: The action to verify.
        domain: The domain providing predicate definitions and the type hierarchy.

    Raises:
        ValueError: If a formula fails verification or the action has no effects.
        NotImplementedError: Propagated from `verify_formula` for unsupported formula types.
    """
    no_effects_msg = (
        f"Action `{action.name}` has no effects. Every action must have at least one effect. "
        "If necessary, define new predicates."
    )

    # An action without a precondition is valid; there is nothing to verify in that case.
    if action.precondition is not None:
        verify_formula(action.precondition, action.parameters, domain)

    # A missing effect is the same user error as an empty one, so report it the same way instead
    # of letting `verify_formula` reject `None` as an unsupported formula type.
    if action.effect is None:
        raise ValueError(no_effects_msg)
    verify_formula(action.effect, action.parameters, domain)
    if len(flatten_formula_to_predicates(action.effect)) == 0:
        raise ValueError(no_effects_msg)
