import logging
from pathlib import Path
from typing import Optional, Tuple

from llm_utils.openai_api.chat import Chat
from llm_utils.textgen_api.textgen_api import TextGenApi

from tp_lodge.motion_planning.motion_validator import (
    MotionSimulationException,
    MotionSimulationResponseCode,
    MotionValidator,
)
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.task_planning.models.planning.plan_result import (
    PlanActionResult,
    PlanLeafActionResult,
    PlanResult,
)
from tp_lodge.task_planning.models.sas.sas_action import SasAction
from tp_lodge.task_planning.pddl_planner.hi_planner.action_node import ActionNode
from tp_lodge.task_planning.pddl_planner.hi_planner.mixins.gen_definitions_mixin import verify_retries
from tp_lodge.task_planning.pddl_planner.hi_planner.shared_action_node_storage import SharedActionNodeStorage
from tp_lodge.utils.planning_cache_utils import parameterize_skill

logger = logging.getLogger(__name__)


class PDDLPredefinedActionNode(ActionNode):
    """Action node whose Python skill (``py_function``) is already given.

    No skill is generated here: :meth:`infer` parameterizes the predefined
    skill for the grounded ``sas_action`` and runs motion validation on it.
    On an ``EFFECT_FAILED`` validation error, ``llm_interface`` is asked for
    the failure reason. The reason is attached to the re-raised exception so
    the parent node can handle it.
    """

    def __init__(
        self,
        operator_id: str,
        parent_operators: list[str],
        storage: SharedActionNodeStorage,
        motion_validator: MotionValidator,
        textgen_api: TextGenApi,
        out_dir: Path,
        #
        parent_chat: Chat,
        sas_action: SasAction,
        py_function: str,
        plan_result: PlanResult,
        env_hash: str,
        problem: PDDLProblem,
        domain: PDDLDomain,
    ):
        super().__init__(
            operator_id=operator_id,
            parent_operators=parent_operators,
            storage=storage,
            motion_validator=motion_validator,
            textgen_api=textgen_api,
            out_dir=out_dir,
        )

        # setup
        # Chat of the parent node; handed to the failure-reason analysis in `infer`.
        self.parent_chat = parent_chat
        self.sas_action = sas_action
        # Unparameterized skill source; written to `skills-code.py` once validated.
        self.py_function = py_function
        self.plan_result = plan_result
        # Environment hash the motion validation starts from.
        self.env_hash = env_hash
        self.problem = problem
        self.domain = domain

    @verify_retries
    def _motion_planning_for_skill(self, replan: bool) -> Tuple[str, str]:
        """Parameterize the predefined skill and validate it in the motion simulator.

        Args:
            replan: When False and ``out_dir/env-state.hash`` exists, its content
                is passed to the validator as ``post_env_hash`` (cached result).
                When True, no cached hash is used.

        Returns:
            A tuple ``(parameterized_py_function, out_env_hash)``.

        Raises:
            MotionSimulationException: With code ``PDDL_PY_TRANSLATION`` if the
                skill cannot be parameterized. Otherwise this is whatever the
                validator raised. It is not raised if the user chooses to retry
                in interactive mode.
        """
        logger.info("Run motion validation for %s...", self.py_function.strip())

        py_function_file = self.out_dir / "skills-code.py"
        env_file = self.out_dir / "env-state.hash"

        try:
            if env_file.is_file() and not replan:
                post_env_hash = env_file.read_text()
                logger.info("Nothing in motion changed -> used cached env hash")
            else:
                post_env_hash = None

            action = self.domain.get_operator_by_id(self.operator_id)
            try:
                parameterized_py_function = parameterize_skill(
                    skill=self.py_function, sas_action=self.sas_action, action=action
                )
            except ValueError as err:
                # Keep the ValueError as the cause for debugging.
                raise MotionSimulationException(
                    expected="",
                    ground_truth="",
                    code=MotionSimulationResponseCode.PDDL_PY_TRANSLATION,
                    message="Parameterization of skill failed: %s" % str(err),
                ) from err

            out_env_hash = self.motion_validator.validate(
                motion=parameterized_py_function,
                problem=self.problem,
                domain=self.domain.remove_for_level(self.parent_operators),
                out_dir=self.out_dir,
                env_hash=self.env_hash,
                post_env_hash=post_env_hash,
                # Persist the post-motion env hash as soon as the validator reports it.
                post_run_motion_callback=lambda env_hash: env_file.write_text(env_hash),
            )
            action.mapped_skill_sequence = [self.py_function]
            py_function_file.write_text(self.py_function)
            env_file.write_text(out_env_hash)
            logger.info("Motion validated")
            return parameterized_py_function, out_env_hash
        except MotionSimulationException:
            self.storage.add_env_retry()
            if not (self.interactive and input("Retry ([y]/n)") in ["", "y"]):
                raise
            logger.warning("Motion Simulation Exception -> Just retry due to non-determinism")
            self.storage.add_val_retry()
            # The retry's result must be propagated. Discarding it would make
            # every successful retry fall through to an error.
            return self._motion_planning_for_skill(replan=replan)

    def infer(self, replan: bool) -> Tuple[PlanActionResult, str]:
        """This action is predefined. We run the motion validation.

        Args:
            replan: Forwarded to :meth:`_motion_planning_for_skill`. When False,
                a cached post-execution env hash may be reused.

        Returns:
            A ``PlanLeafActionResult`` wrapping the parameterized skill, and the
            environment hash produced by the validation.

        Raises:
            MotionSimulationException: ``PDDL_PY_TRANSLATION`` errors are
                re-raised unchanged. ``EFFECT_FAILED`` errors are re-raised as a
                copy carrying the LLM-determined ``reason``.
            NotImplementedError: For any other response code.
        """
        try:
            parameterized_py_function, out_env_hash = self._motion_planning_for_skill(replan=replan)
        except MotionSimulationException as e:
            # if not replan and py_function_file.is_file():
            #     print(e.message)
            #     raise RuntimeError("Motion Validation should not fail here. It did work earlier")

            if e.code == MotionSimulationResponseCode.PDDL_PY_TRANSLATION:
                # translation from pddl to python failed. Reinvoke
                raise

            if e.code == MotionSimulationResponseCode.EFFECT_FAILED:
                # must be handled by parent
                action = self.domain.get_operator_by_id(self.operator_id)
                parameterized_py_function = parameterize_skill(
                    skill=self.py_function, sas_action=self.sas_action, action=action
                )
                # Evaluate predicates on the pre-action environment state.
                self.motion_validator.set_env_hash(self.env_hash)
                init_state = self.motion_validator.get_predicates_evaluation(self.domain, self.problem)
                _, error_reason = self.llm_interface.determine_reason_for_motion_failure(
                    # we pass the full init state, so the reasoner has maximal observance
                    init_state=init_state,
                    domain=self.domain.remove_for_level(self.parent_operators),
                    domain_knowledge=self.storage.domain_knowledge,
                    sas_action=self.sas_action,
                    chat=self.parent_chat,
                    current_py_function=parameterized_py_function,
                    plan_result=self.plan_result,
                    e=e,
                    out_dir=self.out_dir,
                )
                if self.interactive:
                    input("Continue?")
                logger.info("Motion Validation error reason: %s", error_reason)
                raise e.copy_with(reason=error_reason)

                # if error_reason.occurred_within_chat_versions(chat_v1=init_chat, chat_v2=chat):
                #     last_error = e
                #     # motion validation failed. try to replan with exception and if nothing changes escalate
                #     exception = (
                #         "%s\n\nIt seem you chose the wrong python skill for the pddl action. Correct your choice. If you're still convinced of your choice, just return it again."
                #         % (e.to_observation_message(And(*problem.init), sas_action))
                #     )
                #     chat = chat.add_message(UserMessage([TextMessageContent(text=exception)]))
                # else:
                #     raise e

            raise NotImplementedError(
                "Unhandled motion simulation response code: %s" % (e.code,)
            ) from e

        result = PlanLeafActionResult(py_function=parameterized_py_function)

        return result, out_env_hash
