import json
import logging
import re
from pathlib import Path
from typing import List, Optional, Tuple, TypeVar, Dict, Any, Callable

from llm_utils.openai_api.chat import Chat
from llm_utils.openai_api.chat_factory import ChatFactory
from llm_utils.openai_api.text_message_content import TextMessageContent
from llm_utils.prompt_generation.prompt import Prompt
from llm_utils.textgen_api.textgen_api import TextGenApi
from pddl.core import And, Formula
from python_utils.string_utils import get_markup_from_text, remove_comments

from tp_lodge.motion_planning.motion_validator import (
    MotionSimulationErrorReason,
    MotionSimulationErrorType,
    MotionSimulationException,
)
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLPredicate
from tp_lodge.task_planning.models.pddl.pddl_operator import PDDLOperator
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.task_planning.models.planning.plan_result import PlanResult
from tp_lodge.task_planning.models.sas.sas_action import SasAction
from tp_lodge.task_planning.models.sas.sas_plan import SasPlan
from tp_lodge.utils.pddl_parse_utils import (
    get_action_text,
    inject_domain_into_text,
    inject_problem_into_text,
    parse_domain_from_text,
    parse_formula,
    parse_problem_from_text,
)
from tp_lodge.utils.pddl_utils import combine_predicates, get_change_str, get_matching_predicates, get_valid_predicates, get_list_of_predicates
from tp_lodge.utils.llm_json_parser import LLMJsonParser

logger = logging.getLogger(__name__)


class PDDLLLMInterface:

    def __init__(self, prompts_dir: Path, textgen_api: TextGenApi) -> None:
        """Remember the prompt directory and LLM backend and load shared resources.

        Args:
            prompts_dir: Directory holding the prompt templates. It must contain
                ``pddl_output_format.md``, which is read eagerly here.
            textgen_api: Text-generation backend used for the LLM calls; it is
                also handed to the JSON parser helper.
        """
        self.prompts_dir = prompts_dir
        self.textgen_api = textgen_api

        # Read once and reuse: this text fills the output-format placeholder of several prompts.
        output_format_file = self.prompts_dir / "pddl_output_format.md"
        self.pddl_out_format = output_format_file.read_text()

        self.json_parser = LLMJsonParser(textgen_api)

    def __get_f_skill_signatures(self, f_skills: str) -> str:
        function_defs = re.findall("def .*", f_skills)
        return "\n".join(map(lambda f: f"- {f}", function_defs))

    def get_regenerate_pddl_files_prompt(
        self,
        *,
        domain: PDDLDomain,
        problem: PDDLProblem,
        domain_knowledge: str,
        exception: str,
        chat: Optional[Chat],
    ):
        """Build the chat that asks the LLM to repair the PDDL files after planning failed.

        The ``pddl-fix-planning-failed.xml`` template is filled with the domain
        knowledge, the exception text, the shared output format and the
        comma-separated names of the operators defined in *domain*.

        Args:
            domain: Domain whose operator names are listed in the prompt; it is
                also injected into the message when no chat is supplied.
            problem: Problem injected into the message when no chat is supplied.
            domain_knowledge: Text substituted for ``{domain_knowledge}``.
            exception: Text substituted for ``{exception}``.
            chat: Existing conversation to extend, or ``None`` to start a new one.

        Returns:
            A new chat built from the prompt (with domain and problem injected
            into its last message) when *chat* is ``None``; otherwise *chat*
            with the prompt's last message appended.
        """
        prompt = Prompt.load_from_file(self.prompts_dir / "pddl-fix-planning-failed.xml")

        operator_names = ", ".join(op.name for op in domain.operators)
        substitutions = (
            ("{domain_knowledge}", domain_knowledge),
            ("{exception}", exception),
            ("{output_format}", self.pddl_out_format),
            ("{actions_defined_in_domain}", operator_names),
        )
        for needle, replacement in substitutions:
            prompt.replace(needle=needle, replacement=replacement)

        prompt_chat = prompt.to_chat()

        if chat is not None:
            # Continue the running conversation: only the new request is appended.
            return chat.add_message(prompt_chat.messages[-1])

        # Fresh conversation: the last message additionally gets the domain and problem injected.
        last_message = prompt_chat.last_message()
        last_message = inject_domain_into_text(domain=domain, text=last_message)
        last_message = inject_problem_into_text(problem=problem, text=last_message)
        return prompt_chat.replace_last_message(last_message)

    def get_prompt_to_generate_domain(
        self,
        *,
        chat: Optional[Chat],
        sas_action: SasAction,
        action: PDDLOperator,
        problem: PDDLProblem,
        function_stubs: str,
        function_calls: List[str],
        domain: PDDLDomain,
        add_chat_as_message: bool = False,
        hide_problem_goal: bool = False,
    ):
        """Build the action-breakdown prompt for one planned action.

        Args:
            chat: Existing conversation, or ``None`` to create one from the prompt.
            sas_action: Grounded action whose string form fills ``{action}``.
            action: Operator whose text fills ``{action-definition}``.
            problem: Problem injected into the prompt text.
            function_stubs: Source of the available skills; only the ``def``
                lines are listed under ``{function-stubs}``.
            function_calls: Lines joined with newlines for ``{decomposition}``.
            domain: Domain injected into the prompt text.
            add_chat_as_message: When *chat* is given, append the prompt's last
                message to it. Ignored when *chat* is ``None``.
            hide_problem_goal: Inject a copy of *problem* with an empty goal
                state instead of the original.

        Returns:
            The new chat, the extended chat, or *chat* unchanged when it was
            supplied and ``add_chat_as_message`` is false.
        """
        template_path = self.prompts_dir / "pddl-action-breakdown-prompt.xml"

        # Optionally blank out the goal in the problem that gets injected.
        injected_problem = problem.copy_with(goal_state=[]) if hide_problem_goal else problem

        prompt = Prompt.load_from_file(template_path)
        prompt.replace(needle="{action}", replacement=sas_action.to_string())
        prompt.replace(needle="{action-definition}", replacement=get_action_text(action))
        prompt.replace(needle="{function-stubs}", replacement=self.__get_f_skill_signatures(function_stubs))
        prompt.replace(needle="{decomposition}", replacement="\n".join(function_calls))

        with_problem = inject_problem_into_text(problem=injected_problem, text=prompt.prompt)
        prompt.prompt = with_problem
        with_domain = inject_domain_into_text(domain=domain, text=prompt.prompt)
        prompt.prompt = with_domain

        # Substituted only after the problem/domain injection, as in the sibling prompt builders.
        prompt.replace(needle="{output-format}", replacement=self.pddl_out_format)

        if chat is None:
            return prompt.to_chat()
        if add_chat_as_message:
            return chat.add_message(prompt.to_chat().messages[-1])
        return chat

    def get_prompt_to_generate_problem_from_text(
        self,
        *,
        instruction: str,
        domain_skeleton: PDDLDomain,
        domain_knowledge: Optional[str],
        problem_skeleton: PDDLProblem,
    ):
        """Build the chat asking the LLM to write a PDDL problem for an instruction.

        Args:
            instruction: Task description substituted for ``{instruction}``.
            domain_skeleton: Domain injected into the prompt; its predicates are
                listed one per line under ``{predicates}``.
            domain_knowledge: Optional extra knowledge. When given it is
                rendered under a ``### Domain Knowledge`` heading, otherwise
                the placeholder is replaced with an empty string.
            problem_skeleton: Problem injected into the prompt.

        Returns:
            The chat produced from the filled-in ``pddl-problem-gen-prompt.xml``.
        """
        prompt = Prompt.load_from_file(self.prompts_dir / "pddl-problem-gen-prompt.xml")

        prompt.replace(needle="{instruction}", replacement=instruction)
        prompt.prompt = inject_problem_into_text(problem=problem_skeleton, text=prompt.prompt)
        prompt.prompt = inject_domain_into_text(domain=domain_skeleton, text=prompt.prompt)

        # The output format has to be substituted after the injections above.
        prompt.replace(needle="{output_format}", replacement=self.pddl_out_format)

        predicate_lines = "\n".join(str(predicate) for predicate in domain_skeleton.predicates)
        prompt.replace(needle="{predicates}", replacement=predicate_lines)

        if domain_knowledge is None:
            knowledge_section = ""
        else:
            knowledge_section = "### Domain Knowledge\n%s" % domain_knowledge
        prompt.replace(needle="{domain_knowledge}", replacement=knowledge_section)

        return prompt.to_chat()

    def get_prompt_to_generate_definitions_from_text(
        self,
        *,
        instruction: str,
        function_stubs: str,
        domain_skeleton: PDDLDomain,
        domain_knowledge: Optional[str],
        problem_skeleton: PDDLProblem,
        chat: Optional[Chat],
    ):
        """Build the chat asking the LLM to write both PDDL problem and domain.

        A dedicated template is used when the domain skeleton has no predicates
        yet, i.e. when the predicates are to be produced as well.

        Args:
            instruction: Task description substituted for ``{instruction}``.
            function_stubs: Source of the available skills; only the ``def``
                lines are listed under ``{function_stubs}``.
            domain_skeleton: Domain injected into the prompt; its predicates
                fill ``{predicates}``.
            domain_knowledge: Optional extra knowledge, rendered under a
                ``### Domain Knowledge`` heading, or an empty string if ``None``.
            problem_skeleton: Problem injected into the prompt.
            chat: Existing conversation, or ``None``.

        Returns:
            The freshly built chat when *chat* is ``None``; otherwise the result
            of ``chat.replace_system_message`` called with the fresh chat.
        """
        learn_predicates = len(domain_skeleton.predicates) == 0
        if learn_predicates:
            logger.info("We learn predicates")
            template_name = "pddl-problem-domain-gen-prompt-learn-preds.xml"
        else:
            template_name = "pddl-problem-domain-gen-prompt.xml"

        prompt = Prompt.load_from_file(self.prompts_dir / template_name)
        prompt.replace(needle="{instruction}", replacement=instruction)
        prompt.prompt = inject_problem_into_text(problem=problem_skeleton, text=prompt.prompt)
        prompt.prompt = inject_domain_into_text(domain=domain_skeleton, text=prompt.prompt)

        # The output format has to be substituted after the injections above.
        prompt.replace(needle="{output_format}", replacement=self.pddl_out_format)

        predicate_lines = "\n".join(str(predicate) for predicate in domain_skeleton.predicates)
        prompt.replace(needle="{predicates}", replacement=predicate_lines)

        if domain_knowledge is not None:
            knowledge_section = "### Domain Knowledge\n%s" % domain_knowledge
        else:
            knowledge_section = ""
        prompt.replace(needle="{domain_knowledge}", replacement=knowledge_section)

        prompt.replace(needle="{function_stubs}", replacement=self.__get_f_skill_signatures(function_stubs))

        fresh_chat = prompt.to_chat()
        if chat is None:
            return fresh_chat

        # Keep the running conversation but make sure its system message is up to date.
        return chat.replace_system_message(fresh_chat)

    def generate_pddl_files(
        self,
        *,
        domain: PDDLDomain,
        problem: PDDLProblem,
        chat: Chat,
        chat_file: Path,
        supported_predicates: Optional[List[str]],
        parse_domain: bool = True,
    ) -> Tuple[Chat, PDDLProblem, PDDLDomain, List[ValueError]]:
        """Query the LLM with *chat* and parse a PDDL problem (and domain) from the reply.

        Args:
            domain: Current domain, passed to the domain parser; returned
                unchanged when ``parse_domain`` is false.
            problem: Current problem, passed to the problem parser.
            chat: Conversation sent to the LLM.
            chat_file: File that receives the string form of the conversation
                including the response.
            supported_predicates: Forwarded to the problem parser.
            parse_domain: Whether to parse a domain from the response as well.

        Returns:
            ``(chat_with_response, new_problem, new_domain, errors)`` where
            *errors* collects what the domain and problem parsers reported.

        Raises:
            AssertionError: If the response does not consist of exactly one
                text content item.
        """
        logger.info("Generate pddl problem and domain files with LLM...")

        response = self.textgen_api.do_call(chat, call_id="generate-pddl-files")
        res_chat = chat.add_message(response)

        # Persist the exchange before parsing so it is on disk even if parsing fails.
        chat_file.write_text(str(res_chat))

        assert len(response.content) == 1
        assert isinstance(response.content[0], TextMessageContent)
        llm_text = response.content[0].text

        collected_errors = []
        parsed_domain = domain
        if parse_domain:
            parsed_domain, domain_errors = parse_domain_from_text(domain=domain, text=llm_text)
            collected_errors.extend(domain_errors)

        # The problem is parsed against the (possibly just updated) domain.
        parsed_problem, problem_errors = parse_problem_from_text(
            problem=problem,
            existing_domain=parsed_domain,
            text=llm_text,
            supported_predicates=supported_predicates,
        )
        collected_errors.extend(problem_errors)

        return res_chat, parsed_problem, parsed_domain, collected_errors

    def inject_decide_whether_primitive_reprompt(self, chat: Chat, message: str) -> Chat:
        """Append the "decide whether primitive" reprompt carrying *message* to *chat*.

        Args:
            chat: Conversation to extend.
            message: Text substituted for ``{message}`` in the reprompt template.

        Returns:
            The chat returned by ``chat.add_message`` for the template's last message.
        """
        reprompt = Prompt.load_from_file(self.prompts_dir / "pddl-decide-whether-primitive-reprompt.xml")
        reprompt.replace(needle="{message}", replacement=message)

        reprompt_message = reprompt.to_chat().messages[-1]
        return chat.add_message(message=reprompt_message)

    def get_decide_whether_primitive_prompt(
        self,
        sas_action: SasAction,
        operator: PDDLOperator,
        predicates: List[PDDLPredicate],
        problem: PDDLProblem,
        function_stubs: str,
        chat: Chat,
        previous_skills: List[str],
        skill_specific: bool = False,
    ):
        """Append the prompt asking whether an action maps onto a primitive skill.

        Args:
            sas_action: Grounded action whose string form fills ``{robot-skill}``.
            operator: Operator providing the action definition, name and
                parameter names.
            predicates: Predicates listed one per line under ``{predicates}``.
            problem: Supplies the state (its initial state) and the objects.
            function_stubs: Text substituted verbatim for ``{motions}``.
            chat: Conversation to extend.
            previous_skills: Skills executed so far; the last one is also
                reported separately.
            skill_specific: Selects the skill-specific instead of the
                skill-agnostic template.

        Returns:
            *chat* with the last message of the filled-in prompt appended.
        """
        logger.info("Decide whether primitive...")

        # Hard-coded switch: when enabled, the operator's precondition/effect would be
        # replaced by the problem's initial and goal state. Currently always off.
        use_problem_for_init_and_goal = False
        if use_problem_for_init_and_goal:
            operator = operator.copy_with(precondition=And(*problem.initial_state), effect=problem.goal_state)

        template_name = (
            "pddl-decide-whether-primitive-skill-specific.xml"
            if skill_specific
            else "pddl-decide-whether-primitive-skill-agnostic.xml"
        )
        prompt_file = self.prompts_dir / template_name

        quoted_object_names = [f"'{obj.name}'" for obj in problem.objects]

        prompt = Prompt.load_from_file(prompt_file)
        prompt.replace("{predicates}", replacement="\n".join(map(str, predicates)))
        prompt.replace(needle="{robot-skill}", replacement=sas_action.to_string())
        prompt.replace(needle="{action-def}", replacement=get_action_text(operator))
        prompt.replace(needle="{pddl-action-name}", replacement=operator.name)
        prompt.replace(needle="{executed-skills}", replacement="\n".join(previous_skills))

        if len(previous_skills) > 0:
            last_skill = previous_skills[-1]
        else:
            last_skill = "; no previous skill"
        prompt.replace(needle="{last-executed-skill}", replacement=last_skill)

        prompt.replace(needle="{state}", replacement=str(And(*problem.initial_state)))
        prompt.replace(needle="{variables}", replacement=", ".join(operator.param_names()))
        prompt.replace(needle="{constants}", replacement=", ".join(quoted_object_names))
        prompt.replace(needle="{objects}", replacement=problem.get_objects_str())
        # The stubs are passed through in full rather than reduced to function names.
        prompt.replace(needle="{motions}", replacement=function_stubs)

        new_message = prompt.to_chat().messages[-1]
        return chat.add_message(new_message)

    def try_correct_plan(
        self,
        domain: str,
        domain_knowledge: Optional[str],
        problem: str,
        demo_plan: Optional[str],
        last_plan: str,
        last_exception: str,
        out_dir: Path,
        connection: Optional[str] = None,
    ) -> str:
        """Ask the LLM to correct a previously produced plan that was rejected.

        Args:
            domain: PDDL domain text substituted for ``{domain}``.
            domain_knowledge: Optional knowledge text; ``None`` becomes ``""``.
            problem: PDDL problem text substituted for ``{problem}``.
            demo_plan: Optional valid plan. Its presence selects the
                ``...-with-valid-plan`` template; ``None`` becomes ``""``.
            last_plan: The plan to correct, substituted for ``{prev_plan}``.
            last_exception: Feedback on *last_plan*, substituted for ``{response}``.
            out_dir: Directory receiving ``<template stem>.chat`` with the
                full exchange.
            connection: Optional connection id forwarded to the LLM call.

        Returns:
            The parenthesised actions found in the last ``plan`` markup block of
            the response, one per line. An empty string when the response
            contains no such block, mirroring ``try_gen_plan``.

        Raises:
            AssertionError: If *last_plan* or *last_exception* is ``None``, or
                the response is not exactly one text content item.
        """
        logger.info("Try correct plan with llm...")
        assert last_exception is not None and last_plan is not None

        if demo_plan is None:
            prompt_file = self.prompts_dir / "pddl-llm-planning-correct-plan.xml"
        else:
            prompt_file = self.prompts_dir / "pddl-llm-planning-correct-plan-with-valid-plan.xml"
        chat_file = out_dir / ("%s.chat" % prompt_file.stem)

        prompt = Prompt.load_from_file(prompt_file)
        prompt.replace(needle="{domain}", replacement=domain)
        prompt.replace(needle="{domain_knowledge}", replacement=domain_knowledge or "")
        prompt.replace(needle="{problem}", replacement=problem)
        prompt.replace(needle="{prev_plan}", replacement=last_plan)
        prompt.replace(needle="{response}", replacement=last_exception)
        prompt.replace(needle="{demo_plan}", replacement=demo_plan or "")

        chat = prompt.to_chat()

        response = self.textgen_api.do_call(chat, connection_id=connection, call_id="try-gen-plan")

        res_chat = chat.add_message(response)

        chat_file.write_text(str(res_chat))

        assert len(response.content) == 1
        assert isinstance(response.content[0], TextMessageContent)

        text = response.content[0].text

        plan_blocks = get_markup_from_text(text=text, markup=["plan"])
        if len(plan_blocks) < 1:
            # Previously an assertion aborted here, which made the empty-plan fallback
            # unreachable. Return an empty plan instead so the caller can react to it.
            logger.warning("LLM response contains no plan block; returning an empty plan")
            return ""

        # Use the last plan block, strip comments and keep only the "(action arg ...)" entries.
        plan_text = remove_comments(plan_blocks[-1])
        return "\n".join(re.findall(r"(\([\w\- ]+\))", plan_text))

    def try_gen_plan(
        self,
        domain: str,
        domain_knowledge: Optional[str],
        problem: str,
        demo_plan: Optional[str],
        exception: Optional[str],
        out_dir: Path,
        connection: Optional[str] = None,
    ) -> str:
        """Ask the LLM to produce a plan for the given domain and problem.

        Args:
            domain: PDDL domain text substituted for ``{domain}``.
            domain_knowledge: Optional knowledge text; ``None`` becomes ``""``.
            problem: PDDL problem text substituted for ``{problem}``.
            demo_plan: Optional valid plan. Its presence selects the
                ``...-with-valid-plan`` template; ``None`` becomes ``""``.
            exception: Optional feedback from a previous attempt, rendered
                under a ``Last Exception`` heading.
            out_dir: Directory receiving ``<template stem>.chat`` with the
                full exchange.
            connection: Optional connection id forwarded to the LLM call.

        Returns:
            The parenthesised actions found in the last ``plan`` markup block of
            the response, one per line, or ``""`` if there is no such block.

        Raises:
            AssertionError: If the response is not exactly one text content item.
        """
        logger.info("Try plan with llm...")

        template_name = "pddl-llm-planning.xml" if demo_plan is None else "pddl-llm-planning-with-valid-plan.xml"
        prompt_file = self.prompts_dir / template_name
        chat_file = out_dir / ("%s.chat" % prompt_file.stem)

        exception_section = "" if exception is None else "Last Exception\n%s\n" % exception

        prompt = Prompt.load_from_file(prompt_file)
        prompt.replace(needle="{domain}", replacement=domain)
        prompt.replace(needle="{domain_knowledge}", replacement=domain_knowledge or "")
        prompt.replace(needle="{problem}", replacement=problem)
        prompt.replace(needle="{demo_plan}", replacement=demo_plan or "")
        prompt.replace(needle="{exception}", replacement=exception_section)

        request_chat = prompt.to_chat()
        response = self.textgen_api.do_call(request_chat, connection_id=connection, call_id="try-gen-plan")

        # Store the whole exchange next to the other outputs.
        chat_file.write_text(str(request_chat.add_message(response)))

        assert len(response.content) == 1
        assert isinstance(response.content[0], TextMessageContent)
        reply_text = response.content[0].text

        plan_blocks = get_markup_from_text(text=reply_text, markup=["plan"])
        if len(plan_blocks) < 1:
            # No plan block in the response is tolerated and reported as an empty plan.
            return ""

        # Use the last plan block, strip comments and keep only the "(action arg ...)" entries.
        plan_body = remove_comments(plan_blocks[-1])
        return "\n".join(re.findall(r"(\([\w\- ]+\))", plan_body))

    def determine_reason_for_motion_failure(
        self,
        init_state: Formula,
        domain: PDDLDomain,
        sas_action: SasAction,
        current_py_function: str,
        domain_knowledge: str,
        chat: Chat,
        e: MotionSimulationException,
        out_dir: Path,
        plan_result: PlanResult,
        connection: Optional[str] = None,
    ) -> Tuple[Chat, MotionSimulationErrorReason]:
        """Let the LLM explain a motion-simulation failure and classify the fix.

        The template ``pddl-error-reasoner-prompt-skill-error.xml`` must yield
        exactly two messages. The first is appended to *chat* and answered in
        free text (the explanation). The second is appended to a copy of the
        conversation and answered as JSON with the keys ``type_of_fix`` and
        ``operators``.

        Args:
            init_state: State used to render the expected/observed change and
                listed under ``{init-state}``.
            domain: Domain providing the definition of the failing operator.
            sas_action: The action whose execution failed.
            current_py_function: Skill call text for *sas_action*; fills
                ``{skill}`` and is recorded in the action-to-skill mapping.
            domain_knowledge: Text substituted for ``{domain_knowledge}``.
            chat: Conversation to extend.
            e: The simulation exception. It carries either a ``message`` or a
                pair of ``expected`` / ``ground_truth`` effects.
            out_dir: Directory receiving ``motion-validation-error-reasoning.chat``.
            plan_result: Plan executed so far; rendered under ``{actions}`` and
                queried for the action-to-skill mapping.
            connection: Optional connection id forwarded to the LLM calls.

        Returns:
            ``(chat, reason)``. *chat* contains the first reprompt and the
            free-text explanation but not the JSON exchange, which only lives
            in the local ``new_chat``. *reason* bundles the explanation, the
            affected operators and the error type.

        Raises:
            AssertionError: If *e* mixes a message with expected/observed
                effects (or has neither), or a response is not exactly one
                text content item.
            NotImplementedError: If the LLM proposes ``different-skill``.
            RuntimeError: If the proposed fix matches no known category.
        """
        logger.info("Determine reason for motion failure")

        chat_file = out_dir / "motion-validation-error-reasoning.chat"

        prompt_file = self.prompts_dir / "pddl-error-reasoner-prompt-skill-error.xml"
        prompt = Prompt.load_from_file(prompt_file)
        if e.message is not None:
            # different simulation error message
            assert len(e.expected) == 0 and len(e.ground_truth) == 0
            logger.debug("    Message: %s" % e.message)

            prompt.replace("{content}", e.message)

        else:
            # deviation of exp effects to obs effects
            assert len(e.expected) > 0 and len(e.ground_truth) > 0
            logger.debug("\n    Expected: %s\n    Observed: %s" % (e.expected, e.ground_truth))
            # `filtered_init`  only contains valid predicates, but we also want those that are not valid but in effects listed
            # init_that_match_effect = get_matching_predicates(
            #     parse_formula(e.expected, only_variables=False), init_state, use_false_for_non_existent=True
            # )
            # filtered_init = list(set(combine_predicates(filtered_init, init_that_match_effect)))

            # prompt.replace("{content}", f"Expected Effect: {e.expected}\n\nGround Truth Effect: {e.ground_truth}")
            assert init_state is not None
            # Both effects are rendered as changes relative to the initial state.
            prompt.replace(
                "{content}",
                f"Expected Change:\n{get_change_str(init_state, parse_formula(e.expected, only_variables=False))}\n\n"
                f"Ground Truth Change:\n{get_change_str(init_state, parse_formula(e.ground_truth, only_variables=False))}",
            )

        prompt.replace("{domain_knowledge}", domain_knowledge)
        prompt.replace("{actions}", plan_result.to_string())
        prompt.replace("{init-state}", "\n".join(map(str, get_list_of_predicates(init_state))))
        # prompt.replace("{init-state}", "\n".join(map(str, get_valid_predicates(init_state))))
        prompt.replace("{current_operator}", str(domain.get_operator(sas_action.name).definition))
        prompt.replace("{skill}", current_py_function)
        prompt.replace("{sas_action}", sas_action.to_string())
        prompt.replace("{sas_action_name}", sas_action.name)
        prompt_chat = prompt.to_chat()
        # The unpacking requires the template to produce exactly two messages.
        first_reprompt, second_reprompt = prompt_chat.messages
        chat = chat.add_message(first_reprompt)

        # Step 1: free-text explanation of the failure.
        response = self.textgen_api.do_call(
            chat, connection_id=connection, call_id="determine-reason-for-motion-failure-1"
        )
        chat = chat.add_message(response)
        chat_file.write_text(str(chat))

        assert len(response.content) == 1
        assert isinstance(response.content[0], TextMessageContent)

        error_explanation = response.content[0].text

        # Step 2 runs on a separate chat object; `chat` itself is what gets returned.
        # NOTE(review): presumably copy_with()/add_message() return new objects and leave
        # `chat` untouched - confirm against the Chat implementation.
        new_chat = chat.copy_with()

        new_chat = new_chat.add_message(second_reprompt)

        def validate_motion_failure_response(data: Dict[str, Any]) -> bool:
            # The JSON answer must carry both keys that are read below.
            return "type_of_fix" in data and "operators" in data

        data, new_chat = self.json_parser.parse_json_response(
            chat=new_chat,
            call_id="determine-reason-for-motion-failure-2",
            validator=validate_motion_failure_response,
            connection=connection,
            chat_file=chat_file,
        )

        proposed_fix = data["type_of_fix"]
        pddl_operators = data["operators"]

        # if any(reason in proposed_fix for reason in ['pddl-fix', 'prior-skills', 'incorrect-instantiation']):
        # "multiple-skills" is part of this list, so such answers are handled as a PDDL fix
        # and the dedicated `elif "multiple-skills"` branch below can never be taken.
        if any(
            reason in proposed_fix
            for reason in ["pddl-fix", "prior-skills", "incorrect-instantiation", "multiple-skills"]
        ):
            description = "Effects (%s)" % pddl_operators
            action_to_skill_mapping = plan_result.get_mapping_from_action_to_skill()
            action_to_skill_mapping[sas_action.name] = current_py_function
            mapped_pddl_operators = []
            # The LLM may answer with a skill function name instead of a PDDL action name;
            # such names are mapped back to the action whose skill calls that function.
            for pddl_op in pddl_operators:
                if pddl_op not in action_to_skill_mapping:
                    for action, skill in action_to_skill_mapping.items():
                        import ast

                        # Relies on `skill` being source whose first statement is a plain
                        # call such as `name(...)`; other shapes raise here.
                        skill_function = ast.parse(skill).body[0].value.func.id
                        # No break: the loop keeps going after a match, comparing the
                        # re-assigned name against the remaining skills.
                        if pddl_op.replace("-", "_") == skill_function:
                            pddl_op = action
                mapped_pddl_operators.append(pddl_op)
            pddl_operators = mapped_pddl_operators

            error_type = MotionSimulationErrorType.PDDL_FIX
        elif "multiple-skills" in proposed_fix:
            # NOTE(review): unreachable, see the comment on the first branch.
            # the llm generally decides for the wrong response being the cause, since it does not understand the impact of 'predefined'
            # raise NotImplementedError()
            pddl_op = pddl_operators
            description = "Not Predefined"
            error_type = MotionSimulationErrorType.MULTIPLE_SKILLS
        elif "different-skill" in proposed_fix:
            raise NotImplementedError()
            # NOTE(review): the two statements below are dead code after the raise.
            up_to_chat = chat.copy_with(messages=chat.messages[:-2])
            description = "Wrong Skill"
        else:
            raise RuntimeError()

        logger.info("Problem was %s" % description)

        error_reason = MotionSimulationErrorReason(
            explanation=error_explanation, pddl_operators=pddl_operators, error_type=error_type
        )

        return chat, error_reason

    def concat_reprompt_message(self, reason: MotionSimulationErrorReason, affected_operators: List[str], chat: Chat) -> Chat:
        """Append the PDDL-fix feedback derived from *reason* to *chat*.

        Args:
            reason: Failure analysis; its error type must be ``PDDL_FIX`` and
                its explanation fills ``{explanation}``.
            affected_operators: Operator names, joined with ``", "`` for ``{action}``.
            chat: Conversation to extend.

        Returns:
            *chat* with the single feedback message appended.

        Raises:
            AssertionError: If *reason* has another error type or the template
                does not produce exactly one message.
        """
        assert reason.error_type == MotionSimulationErrorType.PDDL_FIX

        feedback = Prompt.load_from_file(self.prompts_dir / "pddl-error-reasoner-pddl-fix-feedback.xml")
        feedback.replace(needle="{explanation}", replacement=reason.explanation)
        feedback.replace(needle="{action}", replacement=", ".join(affected_operators))

        feedback_chat = feedback.to_chat()
        assert len(feedback_chat.messages) == 1

        return chat.add_message(feedback_chat.messages[-1])

    def get_invalid_precondition_prompt(
        self,
        domain_knowledge: str,
        domain: PDDLDomain,
        problem: PDDLProblem,
        chat: Chat,
        plan_result: PlanResult,
        action: SasAction,
        explanation: str,
        out_dir: Path,
        connection: Optional[str] = None,
    ):
        """Ask the LLM which action is responsible for an unmet precondition.

        The ``pddl-precondition-error.xml`` prompt is appended to *chat* and the
        JSON answer (keys ``action`` and ``explanation``) is parsed and checked
        against the operators of *domain*.

        Args:
            domain_knowledge: Text substituted for ``{domain_knowledge}``.
            domain: Domain with uniquely named operators; provides the
                definition of *action* and the list of available actions.
            problem: Supplies the initial and goal state shown in the prompt.
            chat: Conversation to extend.
            plan_result: Executed plan; provides the skills executed so far and
                the plan history.
            action: The action whose precondition was violated.
            explanation: Description of the violation, substituted for
                ``{explanation}``.
            out_dir: The output file is placed in its parent directory and named
                ``<out_dir stem>-<template stem>.chat``.
            connection: Optional connection id forwarded to the LLM call.

        Returns:
            ``(action_name, explanation)`` from the LLM answer. *action_name*
            is ``None`` when the answer is null or contains "none".

        Raises:
            AssertionError: If operator names in *domain* are not unique.
            ValueError: If the answered action cannot be parsed or matches no
                operator of *domain* (re-raised after logging).
        """
        logger.info("Get invalid precondition prompt (%s)..." % explanation)
        assert domain.has_unique_names()

        sas_actions_sofar = plan_result.get_last_composite().get_flattened_skills()
        # Drop everything from the first ":" to the end of each line of the plan string.
        plan_history = re.sub(":(.*)", "", plan_result.to_string())

        prompt_file = self.prompts_dir / "pddl-precondition-error.xml"
        chat_file = out_dir.parent / ("%s-%s.chat" % (out_dir.stem, prompt_file.stem))

        prompt = Prompt.load_from_file(prompt_file)
        prompt.replace(needle="{domain_knowledge}", replacement=domain_knowledge)
        prompt.replace(needle="{initial_state}", replacement=str(And(*get_valid_predicates(problem.initial_state))))
        prompt.replace(needle="{goal_state}", replacement=str(get_valid_predicates(problem.goal_state)))
        prompt.replace(needle="{actions_executed}", replacement="\n".join(sas_actions_sofar))
        prompt.replace(needle="{action}", replacement=action.to_string())
        prompt.replace(needle="{action_definition}", replacement=str(domain.get_operator(action.name).definition))
        available_actions = [a.name for a in domain.operators]
        prompt.replace(needle="{available_actions}", replacement=", ".join(available_actions))
        prompt.replace(needle="{explanation}", replacement=explanation)
        prompt.replace(needle="{plan_history}", replacement=plan_history)

        chat = chat.add_message(prompt.to_chat().messages[-1])

        def validate_precondition_response(data: Dict[str, Any]) -> bool:
            # The JSON answer must carry both keys read by the processor below.
            return "action" in data and "explanation" in data

        def process_precondition_response(data: Dict[str, Any]) -> Tuple[Optional[str], str]:
            # Distinct local names avoid shadowing the enclosing `action` / `explanation`.
            answered_action = data["action"]
            answered_explanation = data["explanation"]

            if answered_action is None or "none" in answered_action.lower():
                return None, answered_explanation
            # A grounded action such as "(name arg ...)" is reduced to its name.
            if "(" in answered_action:
                try:
                    answered_action = SasAction.from_string(answered_action).name
                except (ValueError, AssertionError) as parse_error:
                    # Chain the cause so the underlying parse failure stays visible.
                    raise ValueError("Invalid action format: %s" % answered_action) from parse_error
            # Substring match: accepted if any operator name occurs in the answer.
            if not any(operator.name in answered_action for operator in domain.operators):
                raise ValueError("Action %s not found in domain operators." % answered_action)

            return answered_action, answered_explanation

        try:
            data, chat = self.json_parser.parse_json_response(
                chat=chat,
                call_id="invalid-precondition",
                validator=validate_precondition_response,
                processor=process_precondition_response,
                connection=connection,
                chat_file=chat_file,
            )
            responsible_action, llm_explanation = data
        except (ValueError, AssertionError) as e:
            # Fallback for when validation fails due to domain-specific issues
            logger.error("JSON parsing failed with domain validation: %s" % e)
            raise

        # NOTE(review): this replaces the content of chat_file with the explanation only;
        # the same path was also handed to the JSON parser above - confirm this is intended.
        with chat_file.open("w") as f:
            f.write(llm_explanation)

        return responsible_action, llm_explanation

    def concat_init_state_failed_prompt(self, chat: Chat, expected: str, observed: str) -> Chat:
        """Append the "initial state failed" feedback to *chat*.

        Args:
            chat: Conversation to extend.
            expected: Text substituted for ``{expected}``.
            observed: Text substituted for ``{observed}``.

        Returns:
            *chat* with the single feedback message appended.

        Raises:
            AssertionError: If the template does not produce exactly one message.
        """
        logger.info("Get init state failed prompt...")

        feedback = Prompt.load_from_file(self.prompts_dir / "pddl-init-state-failed.xml")
        for needle, replacement in (("{expected}", expected), ("{observed}", observed)):
            feedback.replace(needle=needle, replacement=replacement)

        feedback_chat = feedback.to_chat()
        assert len(feedback_chat.messages) == 1

        return chat.add_message(feedback_chat.messages[0])

    def get_unmet_goal_prompt(self, chat: Chat, plan: SasPlan, exp_goal: str, gt_goal: str) -> Chat:
        """Append the "goal not reached" feedback for *plan* to *chat*.

        Args:
            chat: Conversation to extend.
            plan: Plan whose string form fills ``{plan}``.
            exp_goal: Text substituted for ``{expected}``.
            gt_goal: Text substituted for ``{observed}``.

        Returns:
            *chat* with the single feedback message appended.

        Raises:
            AssertionError: If the template does not produce exactly one message.
        """
        logger.info("Get unmet goal prompt...")

        feedback = Prompt.load_from_file(self.prompts_dir / "pddl-unmet-goal-prompt.xml")
        substitutions = (
            ("{plan}", plan.to_string()),
            ("{expected}", exp_goal),
            ("{observed}", gt_goal),
        )
        for needle, replacement in substitutions:
            feedback.replace(needle=needle, replacement=replacement)

        feedback_chat = feedback.to_chat()
        assert len(feedback_chat.messages) == 1

        return chat.add_message(feedback_chat.messages[0])
