import json
import logging
import re
from pathlib import Path
from textwrap import indent
from typing import List, Optional, Tuple

from llm_utils import Chat, TextGenApi, TextMessageContent, UserMessage, AssistantMessage
from python_utils.string_utils import extract_first_skill_list
from pddl.core import And, Formula

from tp_lodge.motion_planning.motion_validator import (
    MotionSimulationErrorReason,
    MotionSimulationErrorType,
    MotionSimulationException,
    MotionSimulationResponseCode,
    MotionValidator,
)
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_operator import PDDLOperator
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.task_planning.models.planning.plan_result import (
    PlanActionResult,
    PlanCompositeActionResult,
    PlanResult,
)
from tp_lodge.task_planning.models.sas.sas_action import SasAction
from tp_lodge.task_planning.models.sas.sas_plan import SasPlan
from tp_lodge.task_planning.pddl_generators.pddl_llm_interface import ChatFactory
from tp_lodge.task_planning.pddl_planner.ai_state_transition_retriever import AIStateTransitionRetriever
from tp_lodge.task_planning.pddl_planner.hi_planner.action_node import ActionNode
from tp_lodge.task_planning.pddl_planner.hi_planner.mixins.gen_definitions_mixin import verify_retries
from tp_lodge.task_planning.pddl_planner.hi_planner.pddl_composite_action_node import PDDLCompositeActionNode
from tp_lodge.task_planning.pddl_planner.hi_planner.pddl_predefined_action_node import PDDLPredefinedActionNode
from tp_lodge.task_planning.pddl_planner.hi_planner.shared_action_node_storage import SharedActionNodeStorage
from tp_lodge.utils.python_parse_utils import get_python_code_from_text, validate_python_code_global_variables

logger = logging.getLogger(__name__)


class PDDLActionNodeWrapper(ActionNode):
    """Action node that maps one grounded SAS action onto executable skills.

    For a single ``sas_action`` applied in ``curr_state`` the node:

    1. checks the action's preconditions against the level-specific domain
       (:meth:`_check_preconditions`),
    2. retrieves the sub-problem describing the action's effects,
    3. asks the LLM which skill calls implement the operator
       (:meth:`get_skill_sequence`), and
    4. delegates execution to a :class:`PDDLPredefinedActionNode` (one skill)
       or a :class:`PDDLCompositeActionNode` (several skills), feeding motion
       simulation failures back into the LLM chat and retrying
       (:meth:`_decide_whether_predefined`).

    NOTE(review): ``self.ai_validator`` and ``self.llm_interface`` are used but
    not set in this class; presumably they are provided by ``ActionNode`` —
    confirm.
    """

    def __init__(
        self,
        operator_id: str,
        parent_operators: List[str],
        storage: SharedActionNodeStorage,
        motion_validator: MotionValidator,
        textgen_api: TextGenApi,
        out_dir: Path,
        # arguments below are specific to this wrapper (not passed to ActionNode)
        parent_chat: Chat,
        plan_result: PlanResult,
        subplan_result: PlanCompositeActionResult,
        sas_action: SasAction,
        domain: PDDLDomain,
        problem: PDDLProblem,
        env_hash: str,
        curr_state: List[Formula],
    ):
        super().__init__(
            operator_id=operator_id,
            parent_operators=parent_operators,
            storage=storage,
            motion_validator=motion_validator,
            textgen_api=textgen_api,
            out_dir=out_dir,
        )
        self.parent_chat = parent_chat
        self.plan_result = plan_result
        self.subplan_result = subplan_result
        self.sas_action = sas_action
        self.domain = domain
        self.problem = problem
        self.env_hash = env_hash
        self.curr_state = curr_state

        # Messages produced by this node only. Unless chat history is
        # flattened, they are appended to ``parent_chat`` before each LLM call.
        self.inner_chat = Chat(messages=[])

        # Error and skill selection of the previous failed attempt. Used to
        # detect a replan that yields the same answer, which is then escalated.
        self.last_error: Optional[MotionSimulationException] = None
        self.last_selected_options: Optional[List[str]] = None

    @property
    def _this_level_domain(self) -> PDDLDomain:
        """``self.domain`` reduced for this node's level via ``remove_for_level``."""
        return self.domain.remove_for_level(parent_op_ids=self.parent_operators)

    def _check_preconditions(self):
        """Validate that ``sas_action`` is executable in ``curr_state``.

        The validator runs a one-action plan against the level-specific
        domain and a copy of ``problem`` whose initial state is
        ``curr_state``. If validation fails:

        * an environment retry is recorded on ``storage``,
        * the LLM interface is asked for an explanation of the failure, and
        * an exception is raised.

        Raises:
            MotionSimulationException: with code ``EFFECT_FAILED`` and error
                type ``PDDL_FIX`` if the preconditions do not hold.
        """
        precond_valid, repair_advice = self.ai_validator.validate_plan_executes_successfully(
            domain=str(self._this_level_domain.to_pddl()),
            problem=str(self.problem.copy_with(initial_state=self.curr_state).to_pddl()),
            plan=self.sas_action.to_string(),
        )
        if precond_valid:
            return

        self.storage.add_env_retry()
        # The suggested action is currently not forwarded (see the commented-out
        # ``pddl_action`` below); only the explanation is used.
        _action, explanation = self.llm_interface.get_invalid_precondition_prompt(
            domain_knowledge=self.storage.domain_knowledge,
            domain=self._this_level_domain,
            problem=self.problem,
            chat=self.parent_chat,
            plan_result=self.plan_result,
            action=self.sas_action,
            explanation=repair_advice,
            out_dir=self.out_dir,
        )
        raise MotionSimulationException(
            expected="",
            ground_truth="",
            code=MotionSimulationResponseCode.EFFECT_FAILED,
            message=repair_advice,
            reason=MotionSimulationErrorReason(
                error_type=MotionSimulationErrorType.PDDL_FIX,
                explanation=explanation,
                # pddl_action=_action,
                pddl_operators=[self.sas_action.name],
            ),
        )

    @verify_retries
    def get_skill_sequence(
        self,
        action: PDDLOperator,
        function_stubs: str,
        out_dir: Path,
        previous_skills: List[str],
        connection: Optional[str] = None,
        replan: bool = False,
        cache: bool = True,
        skill_specific: bool = False,
    ) -> List[str]:
        """Ask the LLM which skill calls implement ``action`` and return them.

        The LLM response must contain a ``# Skill Mapping`` section. The first
        skill list in that section is extracted and validated against the
        function names declared in ``function_stubs`` and the operator's
        parameter names. If the answer is malformed or invalid:

        * a correction message is added to ``self.inner_chat``,
        * a validation retry is recorded, and
        * the method recurses with ``replan=True`` and ``cache=False``.

        NOTE(review): ``verify_retries`` presumably bounds this recursion —
        confirm in ``gen_definitions_mixin``.

        Args:
            action: Operator being mapped to skills.
            function_stubs: Python source containing the available skill stubs.
                The ``def <name>`` names found in it are the allowed call
                targets.
            out_dir: Directory where the chat transcript (``.chat``, raw
                ``-chat.json``) and the chosen calls (``.json``) are written.
            previous_skills: Forwarded to the decision prompt.
            connection: Optional connection id for the text-generation API.
            replan: If True, the initial decision prompt is not added; the
                chat is expected to already end with a follow-up message.
            cache: If True and the operator is verified with exactly one
                mapped skill, that mapping is returned without calling the LLM.
            skill_specific: Forwarded to the decision prompt.

        Returns:
            The validated skill call strings. The list may be empty.
        """
        decide_whether_primitive_file = out_dir / "decide-whether-primitive.json"
        chat_file = out_dir / ("%s.chat" % decide_whether_primitive_file.stem)
        chat_raw_file = out_dir / ("%s-chat.json" % decide_whether_primitive_file.stem)

        concatted_chat = (
            Chat.concat_chats(self.parent_chat, self.inner_chat)
            if not self.storage.flatten_chat_history
            else self.inner_chat
        )
        if not replan:
            concatted_chat = self.llm_interface.get_decide_whether_primitive_prompt(
                sas_action=self.sas_action,
                operator=action,
                predicates=self._this_level_domain.predicates,
                problem=self.state_problem,
                function_stubs=function_stubs,
                chat=concatted_chat,
                previous_skills=previous_skills,
                skill_specific=skill_specific,
            )

        # ``inner_chat`` keeps only the messages owned by this node.
        if not self.storage.flatten_chat_history:
            self.inner_chat = concatted_chat.get_all_after(prior_chat=self.parent_chat)
        else:
            self.inner_chat = concatted_chat

        if cache:
            op = self.domain.get_operator_by_id(self.operator_id)
            # Only reuse verified mappings of primitive (single-skill) operators.
            if op.verified and op.mapped_skill_sequence is not None and len(op.mapped_skill_sequence) == 1:
                self.inner_chat = self.inner_chat.add_message(
                    AssistantMessage([TextMessageContent(f"```python\n{op.mapped_skill_sequence}\n```")])
                )
                return op.mapped_skill_sequence

        response = self.textgen_api.do_call(concatted_chat, connection_id=connection, call_id="decide-whether-primitive")

        assert len(response.content) == 1 and isinstance(response.content[0], TextMessageContent)

        text = response.content[0].text

        self.inner_chat = self.inner_chat.add_message(response)
        chat_file.write_text(str(self.inner_chat))

        # Arguments for a corrective retry: never re-add the initial prompt
        # and never short-circuit through the cached mapping.
        kwargs = dict(
            action=action,
            function_stubs=function_stubs,
            out_dir=out_dir,
            previous_skills=previous_skills,
            connection=connection,
            replan=True,
            cache=False,
            skill_specific=skill_specific,
        )

        skill_mapping_section = re.search(r"# Skill Mapping(.*)(?:\[|$)", text, re.DOTALL)

        if skill_mapping_section is None:
            self.inner_chat = self.inner_chat.add_user_text("Could not find # Skill Mapping section in response.")
            self.storage.add_val_retry()
            return self.get_skill_sequence(**kwargs)

        skill_mapping = skill_mapping_section.group(1)
        # Raw string: "\w" in a plain literal is an invalid escape sequence.
        function_defs = re.findall(r"def ([\w\-]+)", function_stubs)

        try:
            code = "\n".join(extract_first_skill_list(skill_mapping))
            func_calls = validate_python_code_global_variables(
                code=code, variables=action.param_names(), func_names=function_defs
            )
        except ValueError as e:
            logger.info("Errors in code found:\n%s", indent(str(e), " " * 4))

            self.inner_chat = self.llm_interface.inject_decide_whether_primitive_reprompt(
                chat=self.inner_chat, message=str(e)
            )
            self.storage.add_val_retry()
            return self.get_skill_sequence(**kwargs)

        decide_whether_primitive_file.write_text(json.dumps({"func_calls": func_calls}))
        chat_raw_file.write_text(json.dumps(ChatFactory().to_json(self.inner_chat)))

        logger.info("Decomposition suggested:\n%s", indent("\n".join(func_calls), prefix=" " * 4))

        return func_calls

    @verify_retries
    def _decide_whether_predefined(
        self,
        action: PDDLOperator,
        replan: bool,
    ) -> Tuple[Optional[PlanActionResult], str]:
        """Choose the skill sequence for ``action`` and execute it via a child node.

        First calls ``subplan_result.unspecify_action`` for ``sas_action``.
        Then the sequence returned by :meth:`get_skill_sequence` selects the
        path:

        * no skills: the goal of ``state_problem`` is checked with an empty
          plan, and the operator is marked verified with an empty mapping;
        * one skill: a :class:`PDDLPredefinedActionNode` executes it;
        * several skills: a :class:`PDDLCompositeActionNode` handles them one
          hierarchy level deeper (``operator_id`` appended to the parents).

        Handling of a ``MotionSimulationException`` from the child:

        * code ``PDDL_PY_TRANSLATION`` or ``UNMET_GOAL``: feedback is added to
          ``inner_chat`` and the decision is retried;
        * code ``EFFECT_FAILED`` with a translation reason: handled the same
          way;
        * a retry that yields the same skill sequence as the last failed
          attempt: the last error is re-raised as a ``PDDL_FIX`` error.

        Args:
            action: Operator for ``self.operator_id``.
            replan: Forwarded to the child node's ``infer``. It also disables
                the cached skill mapping.

        Returns:
            ``(plan_action_result, env_hash)`` from the child, or
            ``(None, self.env_hash)`` when the action maps to no skills.

        Raises:
            MotionSimulationException: on escalation, or for ``EFFECT_FAILED``
                errors that are not translation errors.
            NotImplementedError: for any other response code (the original
                exception is chained as ``__cause__``).
        """
        self.subplan_result.unspecify_action(sas_action=self.sas_action)
        # decide whether the action is a primitive or not
        function_calls = self.get_skill_sequence(
            action=action,
            function_stubs=self.storage.function_stubs,
            out_dir=self.out_dir,
            previous_skills=self.plan_result.get_flattened_skills(),
            replan=self.last_error is not None,
            cache=self.last_error is None and not replan,
        )

        if self.last_selected_options is not None and self.last_selected_options == function_calls:
            assert self.last_error is not None
            logger.info("Replanning produced the same result for whether the action is a primitive. Escalating replan")
            raise self.last_error.copy_with(
                reason=MotionSimulationErrorReason(
                    explanation=self.last_error.to_observation_message(
                        init_state=And(*self.state_problem.initial_state), sas_action=self.sas_action.to_string()
                    ),
                    error_type=MotionSimulationErrorType.PDDL_FIX,
                    pddl_operators=[self.sas_action.name],
                )
            )

        # Only remembered on failure (see the except clause) for escalation.
        last_selected_options = function_calls

        try:
            if len(function_calls) == 1:
                # we don't change anything if `replan` is active, this will automatically be detected by the motion simulation
                child_node = PDDLPredefinedActionNode(
                    operator_id=self.operator_id,
                    parent_operators=self.parent_operators,
                    storage=self.storage,
                    motion_validator=self.motion_validator,
                    textgen_api=self.textgen_api,
                    out_dir=self.out_dir,
                    sas_action=self.sas_action,
                    py_function=function_calls[0],
                    parent_chat=Chat.concat_chats(self.parent_chat, self.inner_chat),
                    plan_result=self.plan_result,
                    env_hash=self.env_hash,
                    problem=self.state_problem,
                    domain=self.domain,
                )
            elif len(function_calls) == 0:
                logger.warning("Action resolved to empty skill: %s" % self.sas_action.to_string())
                self.check_did_fulfill_goal(
                    problem=self.state_problem, domain=self._this_level_domain, sas_plan=SasPlan(actions=[]), env_hash=self.env_hash
                )
                operator = self.domain.get_operator_by_id(self.operator_id)
                operator.verified = True
                operator.mapped_skill_sequence = []
                return None, self.env_hash
            else:
                child_node = PDDLCompositeActionNode(
                    operator_id=self.operator_id,
                    parent_operators=self.parent_operators + [self.operator_id],
                    storage=self.storage,
                    motion_validator=self.motion_validator,
                    textgen_api=self.textgen_api,
                    out_dir=self.out_dir,
                    sas_action=self.sas_action,
                    py_functions=function_calls,
                    parent_chat=Chat.concat_chats(self.parent_chat, self.inner_chat),
                    plan_result=self.plan_result,
                    env_hash=self.env_hash,
                    problem=self.state_problem,
                    domain=self.domain,
                )

            plan_action_result, env_hash = child_node.infer(replan=replan)
            self.domain.get_operator_by_id(self.operator_id).verified = True
            return plan_action_result, env_hash

        except MotionSimulationException as e:
            self.last_error = e
            self.last_selected_options = last_selected_options
            if e.code == MotionSimulationResponseCode.PDDL_PY_TRANSLATION:
                # translation from pddl to python failed. Reinvoke
                logger.info("Python Skill execution failed?!")
                self.inner_chat = self.inner_chat.add_message(
                    UserMessage(
                        content=[
                            TextMessageContent(
                                'Executing the python action returns following exception: "%s"' % e.message
                            )
                        ]
                    )
                )
                self.storage.add_val_retry()
                return self._decide_whether_predefined(action=action, replan=replan)

            elif e.code == MotionSimulationResponseCode.UNMET_GOAL:
                logger.info("Motion validation failed (unmet-goal). Reprompt LLM")

                self.inner_chat = self.llm_interface.get_unmet_goal_prompt(
                    chat=self.inner_chat,
                    plan=SasPlan(actions=[]),
                    exp_goal=e.expected,
                    gt_goal=e.ground_truth,
                )
                logger.debug(self.inner_chat.last_message())

                self.storage.add_val_retry()
                return self._decide_whether_predefined(action=action, replan=replan)

            elif e.code == MotionSimulationResponseCode.EFFECT_FAILED:
                assert e.reason is not None
                e = e.copy_with(
                    message=e.to_observation_message(
                        And(*self.state_problem.initial_state), self.sas_action.to_string()
                    )
                )
                if e.reason.is_translation_reason:
                    # motion validation failed. try to replan with exception and if nothing changes escalate
                    exception = (
                        "%s\n\nThe skill you previously decided on did not work. Here is a explanation what could fix it. Rethink your choice."
                        % e.reason.explanation
                    )
                    self.inner_chat = self.inner_chat.add_message(UserMessage([TextMessageContent(text=exception)]))
                    self.storage.add_val_retry()
                    return self._decide_whether_predefined(action=action, replan=replan)
                else:
                    raise e

            else:
                # Keep the original exception attached so the unhandled code is debuggable.
                raise NotImplementedError(
                    "Unhandled motion simulation response code: %s" % (e.code,)
                ) from e

    def infer(self, replan: bool) -> Tuple[Optional[PlanActionResult], str]:
        """Validate, ground and execute ``sas_action`` in ``curr_state``.

        Sets ``self.state_problem``, which :meth:`get_skill_sequence` and
        :meth:`_decide_whether_predefined` read. Those methods must therefore
        only run after this point.

        Args:
            replan: Forwarded to :meth:`_decide_whether_predefined`.

        Returns:
            ``(plan_action_result, env_hash)`` as produced by
            :meth:`_decide_whether_predefined`.

        Raises:
            MotionSimulationException: propagated from the precondition check
                or from :meth:`_decide_whether_predefined`.
        """
        # Create the output directory up front: ``out_dir`` is already handed
        # to the LLM interface during the precondition check. ``parents=True``
        # tolerates a missing parent directory.
        self.out_dir.mkdir(parents=True, exist_ok=True)

        # validate preconditions
        self._check_preconditions()

        # get sub-problem for this sas_action
        self.state_problem = AIStateTransitionRetriever().retrieve_single(
            domain=self._this_level_domain,
            objects=self.problem.objects,
            current_state=self.curr_state,
            action=self.sas_action.to_string(),
            effects_for_goal=True,  # only get effects do not blow up prompts
        )

        action = self.domain.get_operator_by_id(self.operator_id)
        assert action.name == self.sas_action.name, "Operator name mismatch: %s != %s" % (action.name, self.sas_action.name)

        plan_action_result, env_hash = self._decide_whether_predefined(action=action, replan=replan)

        return plan_action_result, env_hash
