import copy
import json
import logging
import time
from pathlib import Path
from textwrap import indent
from typing import List, Optional, Tuple

from llm_utils.openai_api.chat import Chat
from llm_utils.openai_api.chat_factory import ChatFactory
from llm_utils.textgen_api.textgen_api import TextGenApi
from pddl.core import Formula

from tp_lodge.motion_planning.motion_validator import (
    MotionSimulationException,
    MotionSimulationResponseCode,
    MotionValidator,
)
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.task_planning.models.planning.plan_result import PlanActionResult, PlanCompositeActionResult, PlanResult
from tp_lodge.task_planning.models.sas.sas_plan import SasPlan
from tp_lodge.task_planning.pddl_planner.ai_state_transition_retriever import AIStateTransitionRetriever
from tp_lodge.task_planning.pddl_planner.hi_planner.action_node import ActionNode
from tp_lodge.task_planning.pddl_planner.hi_planner.mixins.gen_definitions_mixin import verify_retries
from tp_lodge.task_planning.pddl_planner.hi_planner.mixins.plan_traversal import PlanTraversal
from tp_lodge.task_planning.pddl_planner.hi_planner.shared_action_node_storage import SharedActionNodeStorage
from tp_lodge.utils.pddl_parse_utils import inject_definitions_into_chat

logger = logging.getLogger(__name__)


class CompositeActionNode(ActionNode):
    """Action node whose operator is refined into a sub-plan on a lower level.

    The node obtains a domain/problem for refining ``operator_id`` via ``_generate_domain``
    (implemented by subclasses), plans over it with ``retry_ai_planning`` and traverses the
    resulting plan with :class:`PlanTraversal`. Motion-validation failures are fed back into
    the chat via ``llm_interface`` and trigger a replan.

    Persisted artefacts (used to resume a previous run):
      * ``out_dir/init-domain.json``      - this level's domain before any changes
      * ``out_dir/generated-problem.json`` - the problem used for planning
      * ``out_dir/generated-chat.json``    - this node's chat
      * ``storage.root_plan_dir/generated-domain.json`` - the full ``last_domain``
    """

    def __init__(
        self,
        operator_id: str,
        parent_operators: List[str],
        storage: SharedActionNodeStorage,
        motion_validator: MotionValidator,
        textgen_api: TextGenApi,
        out_dir: Path,
        catch_all_domain_errors: bool,
        #
        parent_chat: Chat,
        plan_result: PlanResult,
        env_hash: str,
        last_domain: PDDLDomain,
        last_problem: Optional[PDDLProblem],
        #
        inner_chat: Optional[Chat] = None,
    ):
        """Create a composite node.

        Args:
            operator_id: Id of the operator this node decomposes.
            parent_operators: Operators of the enclosing levels; passed to
                ``remove_for_level`` to restrict domains to this level.
            storage: Storage shared between action nodes (flags, retry counters, root dir).
            motion_validator: Validator used to inject init predicates and check plans.
            textgen_api: Text generation backend, forwarded to the base class and traversal.
            out_dir: Directory where this node persists its artefacts.
            catch_all_domain_errors: If true, every ``EFFECT_FAILED`` traversal error triggers
                a reprompt, even when none of the failing operators occur in this node's plan.
            parent_chat: Chat of the enclosing level; prepended to ``inner_chat`` unless
                ``storage.flatten_chat_history`` is set.
            plan_result: Plan result of the enclosing level; a deep copy is handed to the traversal.
            env_hash: Environment hash set on the motion validator before planning.
            last_domain: Domain tree that this node's generated child domain is attached to.
            last_problem: Previously used problem, if any; reused when resuming from disk.
            inner_chat: Chat local to this node; defaults to an empty chat.
        """
        super().__init__(
            operator_id=operator_id,
            parent_operators=parent_operators,
            storage=storage,
            motion_validator=motion_validator,
            textgen_api=textgen_api,
            out_dir=out_dir,
        )
        # plan and failure of the previous attempt, fed back to the planner on replans
        self.last_plan: Optional[SasPlan] = None
        self.last_exception: Optional[MotionSimulationException] = None
        self.last_domain = last_domain
        self.last_problem = last_problem
        self.catch_all_domain_errors = catch_all_domain_errors

        self.traversal = PlanTraversal(
            motion_validator=motion_validator,
            textgen_api=textgen_api,
            storage=storage,
        )

        # setup
        self.inner_chat = inner_chat or Chat(messages=[])
        self.env_hash = env_hash

        self.parent_chat = parent_chat
        self.plan_result = plan_result

    def _generate_domain(self, replan: bool) -> Tuple[PDDLProblem, PDDLDomain, Chat]:
        """Return ``(problem, domain, chat)`` describing this level. Must be overridden."""
        raise NotImplementedError("This method should be implemented in subclasses.")

    def _strip_parent_chat(self, chat: Chat) -> Chat:
        """Return the part of ``chat`` owned by this node.

        With ``flatten_chat_history`` the chat is used as-is; otherwise it is assumed to start
        with ``parent_chat`` and only the messages after it are kept.
        """
        if self.storage.flatten_chat_history:
            return chat
        return chat.get_all_after(prior_chat=self.parent_chat)

    def _flatten_inner_chat(self, init_domain: PDDLDomain, problem: PDDLProblem) -> None:
        """Shorten ``inner_chat`` to keep the LLM context small.

        The chat is cut after its first assistant message and rebuilt with
        ``inject_definitions_into_chat`` from the operator/predicate changes made at this
        level relative to ``init_domain``.
        """
        op_change, pred_change = init_domain.get_change_to(self.last_domain.remove_for_level(self.parent_operators))
        self.inner_chat = inject_definitions_into_chat(
            chat=self.inner_chat.cap_after_first_assistant_message(),
            operator_changes=op_change,
            predicate_changes=pred_change,
            domain=self.last_domain,
            problem=problem,
            prev_problem=self.last_problem,
        )

    def _splitup_high_level_action(self, replan: bool = False) -> Tuple[PDDLProblem, SasPlan]:
        """Generate (or reload) this level's domain/problem and plan over it.

        Args:
            replan: Regenerate domain and problem via ``_generate_domain`` even if a persisted
                problem exists in ``out_dir``.

        Returns:
            The problem returned by ``retry_ai_planning`` and the corresponding plan.
        """
        chat_file = self.out_dir / "generated-chat.json"
        assert self.storage.root_plan_dir is not None
        # the full domain tree (incl. child domains) is persisted at the plan root, not in out_dir
        domain_file = self.storage.root_plan_dir / "generated-domain.json"
        initial_domain_file = self.out_dir / "init-domain.json"
        problem_file = self.out_dir / "generated-problem.json"

        # we save the init domain to keep track of the changes we did at this level
        if initial_domain_file.is_file():
            init_domain = PDDLDomain.loads(initial_domain_file.read_text())
        else:
            init_domain = copy.deepcopy(self.last_domain).remove_for_level(self.parent_operators)
            initial_domain_file.write_text(init_domain.dumps())

        if replan or not problem_file.is_file():
            new_problem, new_domain, concatted_chat = self._generate_domain(replan=replan)

            self.last_domain.add_child_domain(child_domain=new_domain, parent_operator_id=self.operator_id)
            self.inner_chat = self._strip_parent_chat(concatted_chat)

            domain_file.write_text(self.last_domain.dumps())
            problem_file.write_text(new_problem.dumps())
            chat_file.write_text(json.dumps(ChatFactory().to_json(self.inner_chat)))
        else:
            # resume from persisted state
            self.last_domain.update_with(PDDLDomain.loads(domain_file.read_text()))
            # we want to load the problem and later adapt the parent `state_problem` to match this problem
            new_problem = self.last_problem or PDDLProblem.loads(problem_file.read_text())
            self.inner_chat = ChatFactory().from_json(json.loads(chat_file.read_text()))

        # once we generated the domain and problem, we inject the observed init state
        new_problem = self.motion_validator.inject_init_predicates(
            domain=self.last_domain.remove_for_level(self.parent_operators), problem=new_problem
        )

        if self.storage.flatten_chat_history:
            self._flatten_inner_chat(init_domain=init_domain, problem=new_problem)

        last_exception = self.last_exception
        sas_plan, concatted_chat, domain, problem = self.retry_ai_planning(
            # we only provide the operators from this level
            domain=self.last_domain.remove_for_level(self.parent_operators),
            problem=new_problem,
            chat=(
                self.inner_chat
                if self.storage.flatten_chat_history
                else Chat.concat_chats(self.parent_chat, self.inner_chat)
            ),
            out_dir=self.out_dir,
            chat_file=chat_file,
            last_plan=self.last_plan,
            last_exception=last_exception.message if last_exception is not None else None,
            replan_plan=last_exception.replan_plan if last_exception is not None else False,
            domain_knowledge=self.storage.domain_knowledge,
            domain_changed=not replan,
        )
        self.last_domain.add_child_domain(child_domain=domain, parent_operator_id=self.operator_id)
        self.inner_chat = self._strip_parent_chat(concatted_chat)

        if self.storage.flatten_chat_history:
            self._flatten_inner_chat(init_domain=init_domain, problem=problem)

        domain_file.write_text(self.last_domain.dumps())
        problem_file.write_text(problem.dumps())
        chat_file.write_text(json.dumps(ChatFactory().to_json(self.inner_chat)))

        assert self.last_domain is not None and problem is not None and sas_plan is not None

        return problem, sas_plan

    @verify_retries
    def _plan(self, replan_outer: bool, replan: bool = False) -> Tuple[PlanCompositeActionResult, List[Formula], str]:
        """Plan the decomposition of this composite action and traverse the resulting sub-plan.

        On recoverable motion-validation failures the LLM is reprompted and this method
        recurses with ``replan=True``.

        Args:
            replan_outer: Whether the enclosing level requested a replan; forwarded to the traversal.
            replan: Whether this level must regenerate its domain/problem.

        Returns:
            The traversal's sub-plan result, the goal state from ``AIStateTransitionRetriever``
            and the environment hash after traversal.

        Raises:
            MotionSimulationException: An effect failure concerned none of this plan's
                operators and ``catch_all_domain_errors`` is false.
            NotImplementedError: Init-state validation failed with a code other than ``EFFECT_FAILED``.
            RuntimeError: Traversal failed with an unhandled response code.
        """
        # we could pass the action domain, but since the same action can be very different depending on the context,
        # it is not really helpful
        t_start = time.perf_counter()
        self.motion_validator.set_env_hash(hash=self.env_hash)
        problem, sas_plan = self._splitup_high_level_action(replan=replan)
        t_split = time.perf_counter()
        logger.info("Splitup high-level action took %.2f seconds", t_split - t_start)

        self.last_problem = problem

        if not self.storage.flatten_chat_history:
            self.last_domain.mark_not_newly_generated()

        # validate init state
        try:
            self.motion_validator.validate_predicates(
                predicates=problem.initial_state, domain=self.last_domain, problem=problem
            )
        except MotionSimulationException as e:
            if e.code != MotionSimulationResponseCode.EFFECT_FAILED:
                raise NotImplementedError(
                    f"Unhandled motion simulation response code while validating init state: {e.code}"
                ) from e
            logger.info("Validating init state failed. Replan with LLM")
            logger.debug(
                "%s\n%s",
                indent(e.expected, prefix="Expected: "),
                indent(e.ground_truth, prefix="Observed: "),
            )
            self.inner_chat = self.llm_interface.concat_init_state_failed_prompt(
                chat=self.inner_chat, expected=e.expected, observed=e.ground_truth
            )
            self.storage.add_val_retry()
            return self._plan(replan=True, replan_outer=False)

        t_validated = time.perf_counter()
        logger.info("Init state validation took %.2f seconds", t_validated - t_split)

        # TODO: set domain (actions and) predicates fixed, so child doesn't overwrite them

        try:
            out_env_hash, subplan_result = self.traversal.traverse(
                domain=self.last_domain,
                problem=problem,
                sas_plan=sas_plan,
                parent_operators=self.parent_operators,
                out_dir=self.out_dir,
                parent_chat=Chat.concat_chats(self.parent_chat, self.inner_chat),
                env_hash=self.env_hash,
                replan=replan or replan_outer,
                plan_result=copy.deepcopy(self.plan_result),
            )

            # the level domain does not change while collecting skills, so compute it once
            level_domain = self.last_domain.remove_for_level(self.parent_operators)
            mapped_skills = [
                skill
                for action in sas_plan.actions
                for skill in level_domain.get_operator(action.name).mapped_skill_sequence
            ]
            try:
                self.last_domain.get_operator_by_id(self.operator_id).mapped_skill_sequence = mapped_skills
            except KeyError:
                # happens for root composite
                pass

            t_traversed = time.perf_counter()
            logger.info("Plan traversal took %.2f seconds", t_traversed - t_validated)

            self.check_did_fulfill_goal(
                problem=problem,
                domain=self.last_domain.remove_for_level(self.parent_operators),
                sas_plan=sas_plan,
                env_hash=out_env_hash,
                subset_check=True,
            )
        except MotionSimulationException as e:
            if e.code == MotionSimulationResponseCode.EFFECT_FAILED:
                self.last_exception = e
                self.last_plan = sas_plan
                assert e.reason is not None
                assert not e.reason.is_translation_reason

                plan_action_names = {action.name for action in sas_plan.actions}
                affected_operators = [op for op in e.reason.pddl_operators if op in plan_action_names]
                if not (self.catch_all_domain_errors or affected_operators):
                    # the failure stems from operators outside this plan; let an outer level handle it
                    raise

                logger.info("Motion validation failed (effect invalid). Reprompt LLM")
                self.inner_chat = self.llm_interface.concat_reprompt_message(
                    reason=e.reason, affected_operators=affected_operators, chat=self.inner_chat
                )
                self.storage.add_val_retry()
                return self._plan(replan=True, replan_outer=False)

            if e.code == MotionSimulationResponseCode.UNMET_GOAL:
                logger.info("Motion validation failed (unmet-goal). Reprompt LLM")

                self.inner_chat = self.llm_interface.get_unmet_goal_prompt(
                    chat=self.inner_chat,
                    plan=sas_plan,
                    exp_goal=e.expected,
                    gt_goal=e.ground_truth,
                )
                logger.debug(self.inner_chat.last_message())
                self.last_exception = e.copy_with(message=self.inner_chat.last_message())
                self.last_plan = sas_plan

                self.storage.add_val_retry()
                # replan would trigger generation of domain, but we only need to run planner
                return self._plan(replan=True, replan_outer=False)

            raise RuntimeError(
                f"Unhandled motion simulation response code during plan traversal: {e.code}"
            ) from e

        goal_state = AIStateTransitionRetriever().retrieve_goal_state(
            domain=self.last_domain.remove_for_level(self.parent_operators).to_pddl(),
            problem=problem.to_pddl(),
            actions=sas_plan.to_string(),
        )
        return subplan_result, goal_state, out_env_hash

    def infer(self, replan: bool) -> Tuple[PlanActionResult, str]:
        """Plan and traverse the decomposition of this composite action.

        Args:
            replan: Whether the enclosing level requested a replan.

        Returns:
            The sub-plan result and the environment hash after traversal.
        """
        # NOTE: if we remove the actions it only makes sense to also remove the chat history.
        # otherwise, the llm thinks it can use the actions in the chat history
        result, _goal_state, out_env_hash = self._plan(replan_outer=replan)

        return result, out_env_hash
