import logging
import sys
from pathlib import Path
from typing import Tuple

from llm_utils.openai_api.chat import Chat
from llm_utils.textgen_api.textgen_api import TextGenApi

from tp_lodge.motion_planning.motion_validator import MotionValidator
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.task_planning.models.planning.plan_result import PlanResult
from tp_lodge.task_planning.pddl_planner.ai_planner.ai_validator import AIValidator
from tp_lodge.task_planning.pddl_planner.hi_planner.composite_action_node import CompositeActionNode
from tp_lodge.task_planning.pddl_planner.hi_planner.shared_action_node_storage import SharedActionNodeStorage
from tp_lodge.utils.pddl_lib_utils import check_consistency

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class NLActionNode(CompositeActionNode):
    """Root composite action node that plans for a textual instruction.

    The node registers itself with ``operator_id="root"`` and starts from empty
    parent/inner chats and an empty ``PlanResult``. The instruction and the
    problem skeleton are kept on the instance and fed to the definition
    generation step in ``_generate_domain``.
    """

    # Demo-only speed-up: when enabled (and not replanning, and the domain already
    # has operators) only the problem is generated. Disabled by default, which is
    # the behaviour used so far ("don't use for paper").
    _PROBLEM_ONLY_SHORTCUT = False

    def __init__(
        self,
        storage: SharedActionNodeStorage,
        motion_validator: MotionValidator,
        textgen_api: TextGenApi,
        domain: PDDLDomain,
        problem_skeleton: PDDLProblem,
        instruction: str,
        out_dir: Path,
    ):
        """Create the root node.

        Args:
            storage: Shared storage forwarded to the base class.
            motion_validator: Forwarded to the base class; also provides the
                environment hash via ``get_env_hash()``.
            textgen_api: Text generation backend forwarded to the base class.
            domain: Initial domain. It is marked as not newly generated and
                passed to the base class as ``last_domain``.
            problem_skeleton: Problem skeleton used for generation as long as no
                ``last_problem`` is available.
            instruction: Instruction text the definitions are generated from.
            out_dir: Output directory; ``plan`` writes ``planner.log`` into it.
        """
        domain.mark_not_newly_generated()

        # TODO: should be included at some point. Currently only used for ipc:
        # if len(domain.predicates) > 0:
        #     problem_skeleton = motion_validator.inject_init_predicates(domain=domain, problem=problem_skeleton)

        # Set before calling the base constructor, exactly as before.
        self.problem_skeleton = problem_skeleton
        self.instruction = instruction

        super().__init__(
            operator_id="root",
            parent_operators=["root"],
            storage=storage,
            motion_validator=motion_validator,
            textgen_api=textgen_api,
            out_dir=out_dir,
            catch_all_domain_errors=True,
            parent_chat=Chat(messages=[]),
            plan_result=PlanResult(),
            env_hash=motion_validator.get_env_hash(),
            last_domain=domain,
            # Deliberately not ``self.problem_skeleton``: passing it here would
            # overwrite using the generated problem+goal.
            last_problem=None,
            inner_chat=Chat(messages=[]),
        )

    def _validate(self, domain: PDDLDomain, problem: PDDLProblem) -> None:
        """Check that ``domain`` and ``problem`` are consistent.

        Both are converted to PDDL, passed through ``check_consistency`` and then
        through ``AIValidator.validate`` (without a plan, with option ``-v``).

        Raises:
            AssertionError: If the validator reports failure. The validator
                response is appended to the message.
        """
        domain_pddl = domain.to_pddl()
        problem_pddl = problem.to_pddl(force=True)
        check_consistency(domain=domain_pddl, problem=problem_pddl)
        response, success = AIValidator().validate(
            domain=str(domain_pddl), problem=str(problem_pddl), plan=None, options="-v"
        )
        # Raised explicitly instead of via ``assert`` so the check is not stripped
        # when Python runs with -O; the exception type is unchanged.
        if not success:
            raise AssertionError(
                f"Domain and problem are not consistent. Please check the PDDL definitions.\n{response}"
            )

    def _generate_domain(self, replan: bool) -> Tuple[PDDLProblem, PDDLDomain, Chat]:
        """Generate problem and domain definitions for the instruction.

        Args:
            replan: Only consulted by the optional problem-only shortcut.

        Returns:
            The ``(problem, domain, chat)`` triple produced by the generation call.
        """
        domain_knowledge = self.storage.domain_knowledge if self.storage.use_domain_knowledge else None
        domain = self.last_domain.remove_for_level(self.parent_operators)

        if self._PROBLEM_ONLY_SHORTCUT and not replan and len(domain.operators) > 0:
            # Shortcut: keep the existing domain and only generate the problem.
            return self.generate_problem(
                instruction=self.instruction,
                problem_skeleton=self.problem_skeleton,
                domain_skeleton=domain,
                domain_knowledge=domain_knowledge,
                out_dir=self.out_dir,
            )

        # Continue the inner chat only if it already holds messages.
        previous_chat = self.inner_chat if self.inner_chat.messages else None
        return self.generate_definitions_from_text(
            instruction=self.instruction,
            problem_skeleton=self.last_problem or self.problem_skeleton,
            domain_skeleton=domain,
            domain_knowledge=domain_knowledge,
            predicates_description=None,
            function_stubs=self.storage.function_stubs,
            out_dir=self.out_dir,
            chat=previous_chat,
        )

    def _setup_plan_logging(self, stdout_level: int) -> None:
        """Attach console and file handlers to the ``tp_lodge`` and ``state_estimation`` loggers.

        Records at ``stdout_level`` and above go to stdout; records at DEBUG and
        above go to ``<out_dir>/planner.log``. Handlers already attached to these
        loggers are removed and closed, and propagation is switched off.
        """
        formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        log_file = self.out_dir / "planner.log"

        for logger_name in ("tp_lodge", "state_estimation"):
            # Distinct name on purpose: do not shadow the module-level ``logger``.
            target_logger = logging.getLogger(logger_name)
            target_logger.setLevel(logging.DEBUG)

            console_handler = logging.StreamHandler(sys.stdout)
            console_handler.setFormatter(formatter)
            console_handler.setLevel(stdout_level)

            file_handler = logging.FileHandler(log_file)
            file_handler.setFormatter(formatter)
            file_handler.setLevel(logging.DEBUG)

            # Avoid duplicate logs if setup runs more than once; close the old
            # handlers so previously opened log files are not leaked.
            for stale_handler in list(target_logger.handlers):
                target_logger.removeHandler(stale_handler)
                stale_handler.close()

            target_logger.addHandler(console_handler)
            target_logger.addHandler(file_handler)
            target_logger.propagate = False

    def plan(self, stdout_level: int = logging.WARNING):
        """Configure logging, run the planner and return its outcome.

        Args:
            stdout_level: Minimum log level echoed to stdout.

        Returns:
            A tuple ``(plan_result, last_domain, goal_state)`` where ``goal_state``
            is the second value returned by ``self._plan``.
        """
        self._setup_plan_logging(stdout_level)

        # run planner
        subplan_result, goal_state, _ = self._plan(replan_outer=False)
        self.plan_result.inject_composite(composite=subplan_result)

        banner = "$" * 20 + " Planning Finished " + "$" * 20
        skills = "\n".join(self.plan_result.get_flattened_skills())
        result = "\n".join(["", banner, skills, banner])

        # Uses the module-level logger, so the summary is recorded under this
        # module's name.
        logger.info(result)

        return self.plan_result, self.last_domain, goal_state
