import logging
from pathlib import Path
from typing import List, Optional, Tuple

from llm_utils.openai_api.chat import Chat
from llm_utils.textgen_api.textgen_api import TextGenApi
from pddl.core import And

from tp_lodge.motion_planning.dummy_motion_validator import DummyMotionValidator
from tp_lodge.motion_planning.motion_validator import (
    MotionSimulationErrorReason,
    MotionSimulationErrorType,
    MotionSimulationException,
    MotionSimulationResponseCode,
    MotionValidator,
)
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.task_planning.models.planning.plan_result import (
    PlanActionResult,
    PlanResult,
)
from tp_lodge.task_planning.models.sas.sas_action import SasAction
from tp_lodge.task_planning.pddl_planner.hi_planner.composite_action_node import CompositeActionNode
from tp_lodge.task_planning.pddl_planner.hi_planner.shared_action_node_storage import SharedActionNodeStorage
from tp_lodge.utils.pddl_utils import (
    filter_formula_by_predicates,
    get_aligned_effects,
    get_constants_in_term,
)

logger = logging.getLogger(__name__)


class PDDLCompositeActionNode(CompositeActionNode):
    """Composite action node bound to one SAS action of a parent PDDL domain/problem.

    ``infer`` plans the decomposition of the action via the base class'
    ``_plan`` and, unless the motion validator is a ``DummyMotionValidator``,
    compares the effects expected by the parent problem with the predicate
    evaluation reported by the motion validator. A mismatch is raised as a
    ``MotionSimulationException``.
    """

    # When True, the initial state embedded in an effect-mismatch error is
    # narrowed to predicates whose constants all occur in the expected or
    # observed effects. Set to False (e.g. on a subclass) to report the
    # unfiltered initial state instead.
    _FILTER_INIT_STATE_BY_EFFECT_CONSTANTS = True

    def __init__(
        self,
        operator_id: str,
        parent_operators: List[str],
        storage: SharedActionNodeStorage,
        motion_validator: MotionValidator,
        textgen_api: TextGenApi,
        out_dir: Path,
        #
        sas_action: SasAction,
        py_functions: List[str],
        parent_chat: Chat,
        plan_result: PlanResult,
        env_hash: str,
        domain: PDDLDomain,
        problem: PDDLProblem,
    ):
        """Create the node and remember the action it decomposes.

        Args:
            operator_id: Forwarded to the base class; also used to look up
                this node's operator in the parent domain.
            parent_operators: Forwarded to the base class; used with
                ``remove_for_level`` when deriving domains.
            storage: Forwarded to the base class.
            motion_validator: Forwarded to the base class; consulted in
                ``infer`` to obtain the observed predicate evaluation.
            textgen_api: Forwarded to the base class.
            out_dir: Forwarded to the base class.
            sas_action: The SAS action this node decomposes.
            py_functions: Passed as ``function_calls`` when generating the
                sub-domain.
            parent_chat: Forwarded to the base class.
            plan_result: Forwarded to the base class.
            env_hash: Forwarded to the base class.
            domain: Forwarded to the base class as ``last_domain``.
            problem: Forwarded to the base class as ``last_problem`` and
                additionally kept as ``parent_problem``.
        """
        super().__init__(
            operator_id=operator_id,
            parent_operators=parent_operators,
            storage=storage,
            motion_validator=motion_validator,
            textgen_api=textgen_api,
            out_dir=out_dir,
            catch_all_domain_errors=False,
            parent_chat=parent_chat,
            plan_result=plan_result,
            env_hash=env_hash,
            last_domain=domain,
            last_problem=problem,
        )
        self.sas_action = sas_action
        self.py_functions = py_functions

        # Separate reference to the problem given at construction time; the
        # effect check in ``infer`` reads this attribute, not ``last_problem``.
        self.parent_problem = problem

    @property
    def parent_domain(self) -> PDDLDomain:
        """Return ``self.last_domain``.

        NOTE(review): this is an alias, not a snapshot. If the base class
        reassigns ``last_domain`` while planning, this property follows it.
        Confirm that is intended.
        """
        return self.last_domain

    def _generate_domain(self, replan: bool) -> Tuple[PDDLProblem, PDDLDomain, Chat]:
        """Delegate to ``self.generate_domain`` for this node's operator.

        Args:
            replan: When True the chat is not added as a message
                (``add_chat_as_message`` is passed ``not replan``).

        Returns:
            Whatever ``self.generate_domain`` (defined outside this class)
            returns, annotated as a ``(problem, domain, chat)`` tuple.
        """
        assert self.last_problem is not None

        # Use only the inner chat when the storage asks for a flattened
        # history; otherwise prepend the parent chat.
        if self.storage.flatten_chat_history:
            chat = self.inner_chat
        else:
            chat = Chat.concat_chats(self.parent_chat, self.inner_chat)

        return self.generate_domain(
            action=self.parent_domain.get_operator_by_id(self.operator_id),
            sas_action=self.sas_action,
            function_stubs=self.storage.function_stubs,
            out_dir=self.out_dir,
            problem=self.last_problem,
            function_calls=self.py_functions,
            chat=chat,
            domain=self.last_domain.remove_for_level(self.parent_operators),
            add_chat_as_message=not replan,
        )

    def infer(self, replan: bool) -> Tuple[PlanActionResult, str]:
        """Plan the decomposition of this composite action and check its effects.

        Args:
            replan: Forwarded to ``self._plan`` as ``replan_outer``.

        Returns:
            The plan action result and the resulting environment hash, as
            returned by ``self._plan``.

        Raises:
            MotionSimulationException: With code ``EFFECT_FAILED`` when the
                expected effects differ from the effects derived from the
                motion validator's predicate evaluation.
        """
        result, _, out_env_hash = self._plan(replan_outer=replan)

        # A dummy validator provides nothing to compare against.
        if isinstance(self.motion_validator, DummyMotionValidator):
            return result, out_env_hash

        # The check runs against the parent problem. Per the original
        # author's note the sub-actions may overshoot the action's own
        # effects. NOTE(review): confirm this is the intended reference.
        problem_to_check = self.parent_problem
        assert problem_to_check is not None
        assert problem_to_check.initial_state is not None
        assert problem_to_check.goal_state is not None

        exp_effects, gt_effects = get_aligned_effects(
            gt_prior=problem_to_check.initial_state,
            gt_post=self.motion_validator.get_predicates_evaluation(
                domain=self.parent_domain, problem=self.parent_problem
            ),
            exp_prior=problem_to_check.initial_state,
            exp_post=problem_to_check.goal_state,
            supported_predicates=self.motion_validator.get_supported_predicates(domain=self.parent_domain),
        )

        # Restrict both sides to the predicates of the domain obtained by
        # removing all but the last entry of ``parent_operators``.
        parent_predicates = self.last_domain.remove_for_level(self.parent_operators[:-1]).predicates
        parent_pred_names = [p.name for p in parent_predicates]
        exp_effects = filter_formula_by_predicates(exp_effects, known_predicates=parent_pred_names)
        gt_effects = filter_formula_by_predicates(gt_effects, known_predicates=parent_pred_names)

        if gt_effects == exp_effects:
            return result, out_env_hash

        # The planner's effects deviate from the simulation. Possible causes:
        #   1. the sub-plan has side effects on predicates of variables that
        #      are not among this sub-plan's parameters, or
        #   2. the sub-plan overshoots its goal.
        # The two cases are not distinguished here; both are reported as an
        # effect failure.
        # NOTE(review): the log text announces a parent replan while the
        # exception is built with ``replan_plan=False`` - confirm.
        logger.info(
            "Decomposition of %s resulted in goal overshoot. Replanning parent domain required.\n    Expected:     %s\n    Ground Truth: %s",
            self.sas_action.to_string(),
            exp_effects,
            gt_effects,
        )

        error = self._build_effect_mismatch_error(
            problem=problem_to_check,
            exp_effects=exp_effects,
            gt_effects=gt_effects,
            parent_pred_names=parent_pred_names,
        )
        if self.interactive:
            # Blocks until the operator confirms on stdin.
            input("Continue?")
        raise error

    def _build_effect_mismatch_error(
        self,
        problem: PDDLProblem,
        exp_effects,
        gt_effects,
        parent_pred_names: list,
    ) -> MotionSimulationException:
        """Build the ``EFFECT_FAILED`` exception describing an effect mismatch.

        Args:
            problem: Problem whose ``initial_state`` is reported in the
                explanation.
            exp_effects: Expected effects (already predicate-filtered).
            gt_effects: Observed effects (already predicate-filtered).
            parent_pred_names: Predicate names used to filter the reported
                initial state.

        Returns:
            The exception, carrying a ``MotionSimulationErrorReason`` of
            type ``PDDL_FIX`` that names this node's SAS action.
        """
        error = MotionSimulationException(
            expected=str(exp_effects),
            ground_truth=str(gt_effects),
            code=MotionSimulationResponseCode.EFFECT_FAILED,
            replan_plan=False,
        )

        if self._FILTER_INIT_STATE_BY_EFFECT_CONSTANTS:
            effect_constant_names = {
                c.name for c in get_constants_in_term(exp_effects) + get_constants_in_term(gt_effects)
            }
            # ``all`` (not ``any``): keep a predicate only if every constant
            # it mentions also appears in the effects. Predicates without
            # constants pass vacuously.
            init_state = [
                p
                for p in problem.initial_state
                if all(c.name in effect_constant_names for c in get_constants_in_term(p))
            ]
        else:
            init_state = problem.initial_state
        init_state = filter_formula_by_predicates(init_state, known_predicates=parent_pred_names)

        action_str = self.sas_action.to_string()
        return error.copy_with(
            reason=MotionSimulationErrorReason(
                explanation=error.to_observation_message(init_state, action_str, only_valid_init=False),
                pddl_operators=[self.sas_action.name],
                error_type=MotionSimulationErrorType.PDDL_FIX,
            )
        )