import functools
import logging
from pathlib import Path
from textwrap import indent
from typing import List, Optional, Tuple

from llm_utils.openai_api.chat import Chat
from llm_utils.openai_api.text_message_content import TextMessageContent
from llm_utils.openai_api.user_message import UserMessage
from pddl.exceptions import PDDLValidationError

from tp_lodge.motion_planning.motion_validator import MotionValidator
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_operator import PDDLOperator
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.task_planning.models.sas.sas_action import SasAction
from tp_lodge.task_planning.pddl_generators.pddl_llm_interface import PDDLLLMInterface
from tp_lodge.task_planning.pddl_planner.ai_planner.ai_planner import AIPlanner
from tp_lodge.task_planning.pddl_planner.ai_planner.ai_validator import AIValidator
from tp_lodge.task_planning.pddl_planner.hi_planner.out_of_retries_error import OutOfRetriesException
from tp_lodge.task_planning.pddl_planner.hi_planner.shared_action_node_storage import SharedActionNodeStorage

logger = logging.getLogger(__name__)


def verify_retries(f):
    """Guard a ``GenDefinitionsMixin`` method so it refuses to run once retries are exhausted.

    Before every call the wrapper takes the first positional argument (``self``),
    checks that it is a ``GenDefinitionsMixin``, and raises ``OutOfRetriesException``
    if either ``storage.n_val_retries`` or ``storage.n_env_retries`` is ``<= 0``.
    Otherwise it calls ``f`` and returns its result unchanged.

    ``functools.wraps`` preserves the wrapped method's ``__name__``/``__doc__`` (and
    sets ``__wrapped__``) so logs, tracebacks and introspection show the real method
    rather than ``wrapper``.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        instance = args[0]
        assert isinstance(
            instance, GenDefinitionsMixin
        ), "This decorator should only be used with GenDefinitionsMixin instances"
        if instance.storage.n_val_retries <= 0:
            raise OutOfRetriesException("Max retries reached")
        if instance.storage.n_env_retries <= 0:
            raise OutOfRetriesException("Max environment retries reached")
        return f(*args, **kwargs)

    return wrapper


class GenDefinitionsMixin:
    """Mixin that asks an LLM to generate or repair PDDL domain/problem definitions.

    Each public ``generate_*`` / ``regenerate_*`` method builds a prompt through
    ``self.llm_interface`` and then delegates to ``_generate_pddl_files``. That
    method validates the LLM output and re-prompts with the collected errors until
    the definitions validate, or until the ``verify_retries`` guard raises
    ``OutOfRetriesException``.
    """

    def __init__(
        self,
        motion_validator: MotionValidator,
        llm_interface: PDDLLLMInterface,
        ai_validator: AIValidator,
        ai_planner: AIPlanner,
        storage: SharedActionNodeStorage,
    ):
        self.motion_validator = motion_validator
        self.ai_validator = ai_validator
        self.ai_planner = ai_planner
        self.llm_interface = llm_interface
        # Shared retry counters; read by ``verify_retries`` and bumped on each validation retry.
        self.storage = storage

    def regenerate_pddl_definitions(
        self,
        *,
        domain: PDDLDomain,
        problem: PDDLProblem,
        domain_knowledge: str,
        exception: str,
        chat: Optional[Chat],
        chat_file: Path,
        out_dir: Path,
    ):
        """Ask the LLM to fix ``domain``/``problem`` given an ``exception`` message.

        ``out_dir`` is currently unused; it is kept for interface compatibility.

        Returns:
            ``(new_problem, new_domain, res_chat)``.
        """
        chat = self.llm_interface.get_regenerate_pddl_files_prompt(
            domain=domain, problem=problem, domain_knowledge=domain_knowledge, exception=exception, chat=chat
        )

        res_chat, new_problem, new_domain = self._generate_pddl_files(
            domain=domain, problem=problem, chat=chat, chat_file=chat_file
        )
        return new_problem, new_domain, res_chat

    def generate_domain(
        self,
        *,
        chat: Optional[Chat],
        sas_action: SasAction,
        action: PDDLOperator,
        problem: PDDLProblem,
        function_stubs: str,
        function_calls: List[str],
        domain: PDDLDomain,
        out_dir: Path,
        add_chat_as_message: bool = False,
        hide_problem_goal: bool = False,
    ):
        """Generate PDDL definitions that break down ``action``/``sas_action``.

        The conversation is recorded to ``out_dir / "pddl-action-breakdown-prompt.chat"``.

        Returns:
            ``(problem, domain, res_chat)``.
        """
        chat_file = out_dir / "pddl-action-breakdown-prompt.chat"
        chat = self.llm_interface.get_prompt_to_generate_domain(
            chat=chat,
            sas_action=sas_action,
            action=action,
            problem=problem,
            function_stubs=function_stubs,
            function_calls=function_calls,
            domain=domain,
            add_chat_as_message=add_chat_as_message,
            hide_problem_goal=hide_problem_goal,
        )

        # TODO: also pass problem if available
        res_chat, problem, domain = self._generate_pddl_files(
            problem=problem, domain=domain, chat=chat, chat_file=chat_file
        )

        return problem, domain, res_chat

    def generate_problem(
        self,
        *,
        instruction: str,
        domain_skeleton: PDDLDomain,
        domain_knowledge: Optional[str],
        problem_skeleton: PDDLProblem,
        out_dir: Path,
    ):
        """Generate only a PDDL problem from a natural-language ``instruction``.

        The LLM output is requested with ``parse_domain=False``, and this is
        propagated to every validation retry.

        Returns:
            ``(problem, domain, res_chat)``.
        """
        chat_file = out_dir / "pddl-problem-gen-prompt.chat"
        chat = self.llm_interface.get_prompt_to_generate_problem_from_text(
            instruction=instruction,
            domain_skeleton=domain_skeleton,
            domain_knowledge=domain_knowledge,
            problem_skeleton=problem_skeleton,
        )

        res_chat, problem, domain = self._generate_pddl_files(
            domain=domain_skeleton, problem=problem_skeleton, chat=chat, chat_file=chat_file, parse_domain=False
        )

        return problem, domain, res_chat

    def generate_definitions_from_text(
        self,
        *,
        instruction: str,
        function_stubs: str,
        domain_skeleton: PDDLDomain,
        domain_knowledge: Optional[str],
        problem_skeleton: PDDLProblem,
        predicates_description: Optional[str],
        out_dir: Path,
        chat: Optional[Chat],
    ):
        """Generate both PDDL domain and problem from a natural-language ``instruction``.

        ``predicates_description`` is currently not forwarded to the prompt builder.

        Returns:
            ``(problem, domain, res_chat)``.
        """
        chat_file = out_dir / "pddl-problem-domain-gen-prompt.chat"
        chat = self.llm_interface.get_prompt_to_generate_definitions_from_text(
            instruction=instruction,
            function_stubs=function_stubs,
            domain_skeleton=domain_skeleton,
            domain_knowledge=domain_knowledge,
            problem_skeleton=problem_skeleton,
            chat=chat,
        )

        res_chat, problem, domain = self._generate_pddl_files(
            domain=domain_skeleton, problem=problem_skeleton, chat=chat, chat_file=chat_file
        )

        return problem, domain, res_chat

    @verify_retries
    def _generate_pddl_files(
        self, *, domain: PDDLDomain, problem: PDDLProblem, chat: Chat, chat_file: Path, parse_domain: bool = True
    ) -> Tuple[Chat, PDDLProblem, PDDLDomain]:
        """Query the LLM for PDDL files, validate them, and recursively re-prompt on errors.

        Validation runs in three stages; each stage runs only if the previous
        ones produced no errors:

        1. errors reported by ``llm_interface.generate_pddl_files``;
        2. ``new_domain.to_pddl()`` raising ``PDDLValidationError``;
        3. ``ai_validator.validate`` on the rendered domain and problem.

        On any error, the deduplicated messages are appended to the chat as a
        user message, ``storage.add_val_retry()`` is called, and this method
        recurses with the same ``parse_domain`` setting. The ``verify_retries``
        decorator may then raise ``OutOfRetriesException``.

        Returns:
            ``(chat, problem, domain)`` from the attempt that passed validation,
            including the full fix conversation when retries occurred.
        """
        supported_predicates = self.motion_validator.get_supported_predicates(domain=domain)

        res_chat, new_problem, new_domain, errors = self.llm_interface.generate_pddl_files(
            domain=domain,
            problem=problem,
            chat=chat,
            chat_file=chat_file,
            supported_predicates=supported_predicates,
            parse_domain=parse_domain,
        )
        # Work on a copy so the list returned by the LLM interface is not mutated.
        errors = list(errors)

        domain_pddl = None
        if not errors:
            try:
                domain_pddl = new_domain.to_pddl()
            except PDDLValidationError as e:
                errors.append(ValueError("Domain validation error: %s" % str(e)))

        if not errors:
            exception, success = self.ai_validator.validate(
                domain=str(domain_pddl),
                problem=str(new_problem.to_pddl(force=True)),
                plan=None,
                options="-v",
            )
            if not success:
                errors.append(ValueError(exception))

        if not errors:
            return res_chat, new_problem, new_domain

        # Deduplicate while keeping first-seen order. Iterating a set of strings
        # depends on hash randomization, which would make the prompt nondeterministic.
        errors_str = "\n".join(dict.fromkeys(str(e) for e in errors))
        logger.info("Errors in domain found:\n%s", indent(errors_str, " " * 4))
        message = (
            "Following syntax and semantic errors were found:\n%s\n\n Fix the errors and use the provided format\n%s"
            % (errors_str, self.llm_interface.pddl_out_format)
        )
        fix_chat = res_chat.add_message(UserMessage([TextMessageContent(message)]))
        self.storage.add_val_retry()
        # Keep parse_domain so retries interpret the LLM output the same way as the
        # first attempt. Return the retry's chat so the fix conversation is not lost.
        return self._generate_pddl_files(
            domain=new_domain,
            problem=new_problem,
            chat=fix_chat,
            chat_file=chat_file,
            parse_domain=parse_domain,
        )
