from collections import defaultdict
from itertools import product
from typing import Generic, List, Set, TypeVar, Optional

from pddl.core import And, Formula, Predicate
from pddl.logic.base import Not

from tp_lodge.motion_planning.motion_simulator import MotionSimulator
from tp_lodge.motion_planning.motion_validator import (
    MotionSimulationException,
    MotionSimulationResponseCode,
    MotionValidator,
)
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.utils.pddl_utils import (
    combine_predicates,
    filter_formula_by_predicates,
    get_matching_predicates,
    get_pred_change,
)

# Type parameter forwarded to the wrapped `MotionSimulator[T]` (see LocalMotionValidator.__init__);
# what T represents is defined by MotionSimulator, not by this module.
T = TypeVar("T")


class LocalMotionValidator(MotionValidator, Generic[T]):
    """
    Motion validator backed by a local ``MotionSimulator``.

    Motions are executed directly in ``env`` and PDDL predicates are evaluated against it.
    Subclasses must implement ``get_supported_predicates`` and ``validate_predicate`` to map
    PDDL predicates onto queries against the simulator.
    """

    def __init__(self, env: MotionSimulator[T], hide_env_feedback: bool):
        """
        :param env: Simulator used to execute motions and to query its state (hash, init, ...).
        :param hide_env_feedback: When True, the text of a ``ValueError`` raised by the simulator
            while running a motion is replaced by a generic message instead of being forwarded.
        """
        super().__init__()
        self.env = env
        self.hide_env_feedback = hide_env_feedback

    def _run_motion(self, motion: str):
        """
        Execute ``motion`` in the simulator, translating simulator errors.

        :param motion: Motion source passed verbatim to ``env.run_motion``.
        :raises MotionSimulationException: with code ``PDDL_PY_TRANSLATION`` when the simulator
            raises ``TypeError``, or ``EFFECT_FAILED`` when it raises ``ValueError``. The original
            exception is chained as ``__cause__``.
        """
        try:
            self.env.run_motion(motion=motion)
        except TypeError as e:
            raise MotionSimulationException(
                message=str(e),
                code=MotionSimulationResponseCode.PDDL_PY_TRANSLATION,
                expected="",
                ground_truth="",
            ) from e
        except ValueError as e:
            # Optionally mask the simulator's own explanation so it is not surfaced in the message.
            message = (
                "The motion cannot be executed in the current environment state."
                if self.hide_env_feedback
                else str(e)
            )
            raise MotionSimulationException(
                message=message,
                code=MotionSimulationResponseCode.EFFECT_FAILED,
                expected="",
                ground_truth="",
            ) from e

    def get_supported_predicates(self, domain: PDDLDomain) -> Optional[List[str]]:
        """
        Return the names of the predicates this validator can evaluate for ``domain``.

        Must be overridden in subclasses. ``inject_init_predicates`` treats a ``None`` return
        as "no filtering of the problem's initial state".
        """
        raise NotImplementedError("get_supported_predicates() must be implemented in subclasses")

    def validate_predicate(self, pred: Predicate) -> bool:
        """
        Evaluate a grounded predicate against the current environment state.

        Must be overridden in subclasses. Implementations should raise ``ValueError`` for
        predicates they do not support; ``get_predicates_evaluation`` skips those groundings.
        """
        raise NotImplementedError("validate_predicate() must be implemented in subclasses")

    @staticmethod
    def _candidate_objects(term, objects_per_type: dict, all_objects: list) -> list:
        """Return the problem objects that may be substituted for the predicate parameter ``term``."""
        # Sort so the result does not depend on set iteration order (string hashing is randomised).
        type_tags = sorted(term.type_tags)
        if not type_tags:
            # Untyped parameter: any object may fill it.
            return all_objects
        if len(type_tags) == 1:
            return objects_per_type[type_tags[0]]
        # `either` parameter: union of the candidates of every listed type, without duplicates,
        # keeping first-occurrence order.
        return list(dict.fromkeys(obj for tag in type_tags for obj in objects_per_type[tag]))

    def get_predicates_evaluation(self, domain: PDDLDomain, problem: PDDLProblem) -> List[Formula]:
        """
        Ground every domain predicate over the problem objects and evaluate it in the environment.

        An object is a candidate for a parameter when one of the parameter's types appears in
        ``domain.parent_types(obj.type_tag)``; untyped parameters accept every object. Each
        grounding is checked with ``validate_predicate``.

        :param domain: Domain providing the predicate declarations and the type hierarchy.
        :param problem: Problem providing the objects to ground over.
        :return: Groundings that hold as plain predicates, groundings that do not hold wrapped in
            ``Not``. Groundings for which ``validate_predicate`` raises ``ValueError`` are omitted.
        """
        objects_per_type = defaultdict(list)
        for obj in problem.objects:
            for type_tag in domain.parent_types(obj.type_tag):
                objects_per_type[type_tag].append(obj)
        all_objects = list(problem.objects)

        evaluations: List[Formula] = []
        for pddl_predicate in domain.predicates:
            definition = pddl_predicate.definition
            candidates_per_arg = [
                self._candidate_objects(term, objects_per_type, all_objects) for term in definition.terms
            ]
            # Iterate the cartesian product lazily; it can be large.
            for args in product(*candidates_per_arg):
                grounded = Predicate(definition.name, *args)
                try:
                    holds = self.validate_predicate(grounded)
                except ValueError:
                    # Predicate not supported by this validator.
                    # TODO: later, all predicates given in initial domain should be grounded
                    continue
                evaluations.append(grounded if holds else Not(grounded))

        return evaluations

    def validate_predicates(self, predicates: Set[Formula], domain: PDDLDomain, problem: PDDLProblem):
        """
        Check that the expected predicates agree with the environment state.

        ``predicates`` is first restricted to the predicates this validator supports. The result
        is compared, as a set, with ``get_predicates_evaluation``.

        :raises MotionSimulationException: with code ``EFFECT_FAILED`` when the two sets differ.
            ``expected`` and ``ground_truth`` carry the mismatching state rendered as ``And``
            formulas.
        """
        gt_predicates = self.get_predicates_evaluation(domain=domain, problem=problem)
        supported_predicates = filter_formula_by_predicates(
            And(*predicates), known_predicates=self.get_supported_predicates(domain=domain)
        ).operands
        if set(supported_predicates) == set(gt_predicates):
            return

        # Build the expected state from get_pred_change's base list plus the differences:
        # removed predicates are asserted negated, added ones as-is.
        exp_init_state, removed, added = get_pred_change(gt_predicates, supported_predicates)
        for removed_pred in removed:
            exp_init_state.append(Not(removed_pred))
        for added_pred in added:
            exp_init_state.append(added_pred)
        # Ground-truth counterpart of the expected state, restricted to what get_matching_predicates
        # selects from the environment evaluation.
        gt_init_state = get_matching_predicates(exp_init_state, gt_predicates)
        raise MotionSimulationException(
            ground_truth=str(And(*gt_init_state)),
            expected=str(And(*exp_init_state)),
            code=MotionSimulationResponseCode.EFFECT_FAILED,
        )

    def get_env_hash(self):
        """Return the simulator's current state hash (delegates to ``env.get_hash``)."""
        return self.env.get_hash()

    def set_env_hash(self, hash):
        """Restore the simulator to the state identified by ``hash`` (delegates to ``env.set_hash``)."""
        return self.env.set_hash(hash=hash)

    def init_state(self, seed: int):
        """Initialise the simulator state with ``seed`` (delegates to ``env.init_state``)."""
        return self.env.init_state(seed=seed)

    def inject_init_predicates(self, domain: PDDLDomain, problem: PDDLProblem) -> PDDLProblem:
        """
        Return a copy of ``problem`` whose initial state is the environment's evaluated predicates.

        Both ``initial_state`` and ``grounder_initial_state`` receive the new state. When the
        validator declares supported predicates and the problem already has an initial state,
        the existing initial-state predicates selected by ``filter_formula_by_predicates(...,
        inverse=True)`` are combined in. Presumably these are the custom predicates outside the
        supported set that the environment cannot evaluate.
        """
        predicates = self.get_predicates_evaluation(domain=domain, problem=problem)
        supported_predicates = self.get_supported_predicates(domain=domain)
        if supported_predicates is not None and problem.initial_state is not None:
            custom_predicates = filter_formula_by_predicates(
                problem.initial_state, known_predicates=supported_predicates, inverse=True
            )
            predicates = combine_predicates(predicates, custom_predicates)
        return problem.copy_with(initial_state=predicates, grounder_initial_state=predicates)
