import logging
from pathlib import Path
from tp_lodge.utils.pddl_utils import filter_formula_by_predicates, combine_predicates
from abc import abstractmethod
from typing import List, Optional, Set, Callable

from pddl.core import And, Formula


from pddl.core import And, Formula

from tp_lodge.motion_planning.motion_validator import (
    MotionSimulationException,
    MotionSimulationResponseCode,
    MotionValidator,
)
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.utils.pddl_utils import get_valid_predicates, is_predicates_subset

from state_estimation.predicate_grounder import PredicateGrounder
from state_estimation.se_variable import SEVariable

# Module-scoped logger; SEMotionValidator.validate uses it to dump the
# state-estimation variables when a motion simulation fails.
logger: logging.Logger = logging.getLogger(__name__)


class SEMotionValidator(MotionValidator):
    """Motion validator that checks PDDL predicates against state estimation.

    Domain predicates are grounded from the state-estimation variables returned
    by :meth:`get_variables`, using the injected :class:`PredicateGrounder`.
    Subclasses must implement :meth:`get_variables`.
    """

    def __init__(self, grounder: PredicateGrounder):
        """Store the grounder used to evaluate domain predicates.

        Args:
            grounder: Grounds domain predicates against the current
                state-estimation variables (see ``get_predicates_evaluation``).
        """
        self.grounder = grounder

    def get_supported_predicates(self, domain: PDDLDomain) -> Optional[List[str]]:
        """Return the names of the domain predicates this validator can evaluate.

        A predicate is supported when it is flagged ``is_visual`` or
        ``predefined``. This implementation always returns a list, never ``None``.
        """
        return [p.name for p in domain.predicates if p.is_visual or p.predefined]

    @abstractmethod
    def get_variables(self) -> List[SEVariable]:
        """Return the current state-estimation variables.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("This method should be implemented by subclasses.")

    def validate(
        self,
        motion: str,
        problem: PDDLProblem,
        domain: PDDLDomain,
        env_hash: str,
        out_dir: Path,
        post_run_motion_callback: Optional[Callable[[str], None]] = None,
        post_env_hash: Optional[str] = None,
    ) -> str:
        """Run the base-class validation and dump SE variables if it fails.

        All arguments are forwarded unchanged to ``MotionValidator.validate``.
        If that raises :class:`MotionSimulationException`, the current
        state-estimation variables are logged at INFO level, one per line, to
        aid debugging. The exception is then re-raised with its original
        traceback.

        Returns:
            Whatever ``MotionValidator.validate`` returns.
        """
        try:
            return super().validate(motion, problem, domain, env_hash, out_dir, post_run_motion_callback, post_env_hash)
        except MotionSimulationException:
            # Lazy %-formatting; the emitted text is identical to "\n" + joined.
            logger.info("\n%s", "\n".join(map(str, self.get_variables())))
            # A bare raise re-raises the active exception without adding a frame.
            raise

    def get_predicates_evaluation(self, domain: PDDLDomain, problem: PDDLProblem, verify: bool = True) -> Formula:
        """Ground every domain predicate against the current SE variables.

        Args:
            domain: PDDL domain whose predicates are grounded. Its type
                hierarchy must be flat (asserted below).
            problem: Not used by this implementation; kept for interface
                compatibility with callers that pass it.
            verify: Forwarded to ``PredicateGrounder.ground_state``.

        Returns:
            An ``And`` conjunction of the grounded predicate evaluations.

        Raises:
            AssertionError: If the domain's type hierarchy is not flat (only
                when assertions are enabled).
        """
        # Fetch the variables first, as before the check, in case
        # get_variables has side effects subclasses rely on.
        se_variables = self.get_variables()

        # Only flat type hierarchies are supported: every declared type must
        # have exactly one entry in child_types (presumably itself; confirm
        # against PDDLDomain.child_types).
        assert all(
            len(domain.child_types(t)) == 1 for t in domain.types.keys()
        ), "SEMotionValidator requires a flat PDDL type hierarchy"

        predicates_evaluation = self.grounder.ground_state(
            predicates=domain.predicates, variables=se_variables, verify=verify
        )
        return And(*predicates_evaluation)

    def validate_predicates(self, predicates: Set[Formula], domain: PDDLDomain, problem: PDDLProblem) -> None:
        """Check that the expected predicates hold in the estimated state.

        The expected predicates are first passed through
        ``filter_formula_by_predicates`` with the supported predicate names
        (see ``get_supported_predicates``). The result is then checked for
        inclusion in the grounded state.

        Args:
            predicates: Expected predicate literals after a motion.
            domain: PDDL domain used for grounding and for the supported set.
            problem: Forwarded to ``get_predicates_evaluation``.

        Raises:
            MotionSimulationException: With code ``EFFECT_FAILED`` when the
                filtered expected predicates are not a subset of the grounded
                state.
        """
        gt_predicates = self.get_predicates_evaluation(domain=domain, problem=problem)
        # Use a separate name instead of rebinding the ``predicates`` parameter,
        # since the filtered value is a formula, not the input set.
        expected = filter_formula_by_predicates(
            And(*predicates), known_predicates=self.get_supported_predicates(domain)
        )
        if not is_predicates_subset(expected, subset_of=gt_predicates):
            raise MotionSimulationException(
                ground_truth=str(And(*get_valid_predicates(gt_predicates))),
                expected=str(And(*expected)),
                code=MotionSimulationResponseCode.EFFECT_FAILED,
            )

    def inject_init_predicates(self, domain: PDDLDomain, problem: PDDLProblem):
        """Return ``problem`` rebuilt with the grounded SE state as its initial state.

        If ``problem`` already has an initial state, its facts over predicates
        this validator does NOT support are kept (``inverse=True`` filter). They
        are merged with the grounded state via ``combine_predicates``.

        Args:
            domain: PDDL domain used for grounding and for the supported set.
            problem: Problem whose initial state is replaced.

        Returns:
            The result of ``problem.copy_with`` with both ``initial_state``
            and ``grounder_initial_state`` set to the (possibly merged) state.
        """
        predicates_evaluation = self.get_predicates_evaluation(domain=domain, problem=problem)
        if problem.initial_state is not None:
            # Keep custom (unsupported) facts the grounder cannot produce itself.
            custom_predicates = filter_formula_by_predicates(
                problem.initial_state, known_predicates=self.get_supported_predicates(domain), inverse=True
            )
            predicates_evaluation = combine_predicates(predicates_evaluation, custom_predicates)
        # NOTE(review): grounder_initial_state also receives the merged custom
        # predicates, not only the grounded ones; confirm this is intended.
        return problem.copy_with(
            initial_state=predicates_evaluation,
            grounder_initial_state=predicates_evaluation,
        )
