import logging
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import List, Optional, Callable

from llm_utils.openai_api.chat import Chat
from pddl.core import And, Formula

from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.utils.pddl_utils import (
    filter_formula_by_predicates,
    get_aligned_effects,
    get_valid_predicates,
    get_list_of_predicates,
    is_predicates_subset,
    pred_len,
)

logger = logging.getLogger(__name__)


class MotionSimulationResponseCode(str, Enum):
    """String-valued outcome codes carried by ``MotionSimulationException.code``.

    Because the enum also derives from ``str``, every member compares equal to
    its raw string value.
    """

    # Note: the value is the literal string "none", not "success".
    SUCCESS = "none"
    # Used by ``MotionValidator.validate`` when the observed effects differ
    # from the effects the planner expected.
    EFFECT_FAILED = "effect_failed"
    # Not raised in this module; presumably signals that the goal was not
    # reached -- confirm against the callers that emit it.
    UNMET_GOAL = "unmet_goal"
    # Not raised in this module; presumably signals a PDDL-to-Python
    # translation problem -- confirm against the callers that emit it.
    PDDL_PY_TRANSLATION = "pddl_py_translation"


class MotionSimulationErrorType(str, Enum):
    """String-valued category stored in ``MotionSimulationErrorReason.error_type``.

    ``MotionSimulationErrorReason.is_translation_reason`` is true for every
    member except ``PDDL_FIX``.
    """

    # The only category that is NOT treated as a translation reason.
    PDDL_FIX = "pddl-fix"
    # Presumably: the action needs several skills -- confirm with the producer.
    MULTIPLE_SKILLS = "multiple-skills"
    # Presumably: a different skill should have been used -- confirm with the producer.
    DIFFERENT_SKILL = "different-skill"


@dataclass
class MotionSimulationErrorReason:
    """Diagnosis attached to a failed motion simulation.

    Besides describing the failure, an instance can be asked whether the chat
    message that introduced the error is part of a given chat (or of the
    messages added between two versions of a chat).
    """

    # Free-text explanation of what went wrong.
    explanation: str
    # PDDL operators related to the failure (exact meaning is defined by the
    # producer of this object -- confirm there).
    pddl_operators: List[str]
    # Category of the failure; drives ``is_translation_reason``.
    error_type: MotionSimulationErrorType
    # Text of the chat message that introduced the error. The chat helpers
    # below read this attribute, so it has to exist on every instance; it
    # defaults to ``None`` so that construction with only the three fields
    # above keeps working.
    message_that_introduced_error: Optional[str] = None

    @property
    def is_translation_reason(self) -> bool:
        """Return ``True`` for every error type other than ``PDDL_FIX``."""
        return self.error_type != MotionSimulationErrorType.PDDL_FIX

    def _introduced_by(self, message) -> bool:
        """Return whether ``message`` is the one recorded as introducing the error.

        Assumes ``message.content[0].text`` holds the message text, exactly as
        the chat helpers have always accessed it. Always ``False`` while no
        introducing message has been recorded.
        """
        marker = self.message_that_introduced_error
        return marker is not None and message.content[0].text == marker

    def occurred_in_chat(self, chat: Chat) -> bool:
        """Return whether any message of ``chat`` is the introducing message."""
        return any(self._introduced_by(m) for m in chat.messages)

    def occurred_within_chat_versions(self, chat_v1: Chat, chat_v2: Chat) -> bool:
        """Return whether the introducing message was added between two chat versions.

        Args:
            chat_v1: The earlier chat; expected to be a subset of ``chat_v2``
                (only a warning is logged if it is not).
            chat_v2: The later chat; must contain the introducing message.

        Returns:
            ``True`` if one of the messages present in ``chat_v2`` but not in
            ``chat_v1`` is the introducing message, ``False`` otherwise.

        Raises:
            AssertionError: If ``chat_v2`` does not contain the introducing
                message at all.
        """
        assert self.occurred_in_chat(chat_v2), "chat_v2 does not contain the message that introduced the error"
        if not all(m in chat_v2.messages for m in chat_v1.messages):
            logger.warning("chat_v1 must be subset of chat_v2")

        # Only messages that are new in chat_v2 are candidates.
        return any(self._introduced_by(m) for m in chat_v2.messages if m not in chat_v1.messages)


@dataclass
class MotionSimulationException(Exception):
    """Raised when simulating a motion does not produce the expected outcome.

    The generated dataclass ``__init__`` does not forward anything to
    ``Exception.__init__``, so with keyword construction ``self.args`` stays
    empty and the default ``str(exc)`` would be ``""``. ``__str__`` is
    therefore defined explicitly so that logs and tracebacks are informative.

    NOTE(review): ``@dataclass`` with its default ``eq=True`` makes instances
    unhashable; keep that in mind before storing these exceptions in sets or
    using them as dict keys.
    """

    # String form of the effects the planner expected.
    expected: str
    # String form of the effects observed in the simulation.
    ground_truth: str
    # Machine-readable outcome code.
    code: MotionSimulationResponseCode
    # Optional execution error text; when set it replaces the expected/observed
    # comparison in ``to_observation_message``.
    message: Optional[str] = None
    # Optional diagnosis of the failure.
    reason: Optional[MotionSimulationErrorReason] = None
    # Flag for the caller; not interpreted inside this class.
    replan_plan: bool = True

    def __str__(self) -> str:
        """Return a short human-readable description of the failure."""
        if self.message is not None:
            return str(self.message)
        return f"expected effect '{self.expected}' but observed '{self.ground_truth}'"

    def copy_with(
        self, *, reason: Optional[MotionSimulationErrorReason] = None, message: Optional[str] = None
    ) -> "MotionSimulationException":
        """Return a copy with ``reason`` and/or ``message`` replaced.

        A falsy argument (``None`` or an empty string) keeps the current value
        of the corresponding field. All other fields are copied unchanged.
        """
        return MotionSimulationException(
            expected=self.expected,
            ground_truth=self.ground_truth,
            code=self.code,
            reason=reason or self.reason,
            message=message or self.message,
            replan_plan=self.replan_plan,
        )

    def to_observation_message(self, init_state: And, sas_action: str, *, only_valid_init: bool = True) -> str:
        """Render the failure as a textual observation.

        Args:
            init_state: State of the simulated environment before the action.
            sas_action: The action that the planner executed.
            only_valid_init: When true, the state shown is rebuilt as an
                ``And`` over ``get_valid_predicates(init_state)``; otherwise
                ``init_state`` is shown as given.

        Returns:
            A preamble describing the prior state and the action, followed by
            either the execution error (when ``message`` is set) or the
            expected-versus-observed effect comparison.
        """
        shown_state = And(*get_valid_predicates(init_state)) if only_valid_init else init_state
        preamble = (
            "I validated the action in a simulated environment.\n"
            "Before executing the action, the simulated environment was in the following state: "
            f"'{shown_state!s}'\n"
            "\n"
            f"The planner then executed the action: '{sas_action!s}'"
        )

        if self.message is not None:
            return f"{preamble}\nThe execution failed with: {self.message!s}"

        # The column alignment of "Observed Effect:   " is intentional and is
        # part of the emitted text.
        return (
            f"{preamble}\n"
            "\n"
            "However, the observed effect in the simulation differed from the effect expected by the planner:\n"
            f"Expected Effect: '{self.expected!s}'\n"
            f"Observed Effect:   '{self.ground_truth!s}'"
        )


class MotionValidator:
    """Base class that checks a motion's simulated effects against a PDDL problem.

    Concrete validators implement the environment-specific hooks below.

    Note: the hooks are decorated with ``@abstractmethod`` but this class does
    not derive from ``abc.ABC``, so the decorator is not enforced when a
    subclass is instantiated; the ``NotImplementedError`` raised by each
    default body is the effective guard.
    """

    def validate(
        self,
        motion: str,
        problem: PDDLProblem,
        domain: PDDLDomain,
        env_hash: str,
        out_dir: Path,
        post_run_motion_callback: Optional[Callable[[str], None]] = None,
        post_env_hash: Optional[str] = None,
    ) -> str:
        """Run ``motion`` from a known environment state and verify its effects.

        Args:
            motion: The motion to execute in the simulation.
            problem: PDDL problem providing the expected initial state and the
                expected goal state.
            domain: PDDL domain used to evaluate predicates.
            env_hash: Hash of the environment state to start from.
            out_dir: Not used by this implementation.
            post_run_motion_callback: Optional callable invoked with the hash
                of the environment state reached after the motion.
            post_env_hash: When given, the environment is set to this state
                instead of executing ``motion``.

        Returns:
            The hash of the environment state after the motion.

        Raises:
            RuntimeError: If the (supported part of the) expected initial
                state is not a subset of the predicates observed before the
                motion.
            MotionSimulationException: If the observed effects differ from the
                expected effects.
        """
        logger.info("executing %s in simulation", motion.strip())
        # reset env to appropriate state
        self.set_env_hash(env_hash)
        supported_predicates = self.get_supported_predicates(domain=domain)
        gt_prior_predicates = self.get_predicates_evaluation(domain=domain, problem=problem)

        # Restrict the expected initial state to predicates the environment
        # can evaluate; ``None`` means no restriction is applied.
        if supported_predicates is None:
            supported_init = problem.initial_state
        else:
            supported_init = get_list_of_predicates(
                filter_formula_by_predicates(And(*problem.initial_state), known_predicates=supported_predicates)
            )

        if not is_predicates_subset(supported_init, gt_prior_predicates):
            raise RuntimeError("predicates invalid, but this should never happen")

        motion_was_executed = post_env_hash is None
        if motion_was_executed:
            self._run_motion(motion=motion)
        else:
            self.set_env_hash(post_env_hash)

        # get current env state
        post_hash = self.get_env_hash()
        if motion_was_executed:
            # FIXME: hack. we figured that after resetting the poses are slightly different
            # than after the initial execution. Not validated yet
            self.set_env_hash(post_hash)

        if post_run_motion_callback is not None:
            post_run_motion_callback(post_hash)

        gt_post_predicates = self.get_predicates_evaluation(domain=domain, problem=problem)

        exp_effects, gt_effects = get_aligned_effects(
            gt_prior=gt_prior_predicates,
            exp_prior=problem.initial_state,
            gt_post=gt_post_predicates,
            exp_post=problem.goal_state_list,
            supported_predicates=supported_predicates,
        )
        assert pred_len(exp_effects) == pred_len(gt_effects)

        if gt_effects == exp_effects:
            return post_hash

        expected_text = str(exp_effects)
        observed_text = str(gt_effects)
        logger.info("Motion Validation Failed\nExpected:     %s\nGround Truth: %s", expected_text, observed_text)

        # while the effects should be enough, the LLM struggles to see that leaving out the predicate
        # means it did not change.
        # Therefore, we also add the predicates that deviate between the effects but did not change

        raise MotionSimulationException(
            expected=expected_text,
            ground_truth=observed_text,
            code=MotionSimulationResponseCode.EFFECT_FAILED,
        )

    @abstractmethod
    def _run_motion(self, motion: str):
        """Execute ``motion`` in the simulated environment."""
        raise NotImplementedError()

    @abstractmethod
    def get_predicates_evaluation(self, domain: PDDLDomain, problem: PDDLProblem) -> List[Formula]:
        """Return the predicates evaluated on the current environment state.

        ``validate`` treats the result as the ground truth before and after
        the motion.
        """
        raise NotImplementedError()

    @abstractmethod
    def validate_predicates(self, predicates: List[Formula], domain: PDDLDomain, problem: PDDLProblem):
        """Validate ``predicates``; the exact contract is defined by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def get_env_hash(self) -> str:
        """Return a hash identifying the current environment state.

        ``validate`` passes such a hash back to ``set_env_hash``.
        """
        raise NotImplementedError()

    @abstractmethod
    def init_state(self, seed: int):
        """Initialise the environment state from ``seed`` (contract defined by subclasses)."""
        raise NotImplementedError()

    @abstractmethod
    def set_env_hash(self, hash: str):
        """Put the environment into the state identified by ``hash``."""
        raise NotImplementedError()

    @abstractmethod
    def inject_init_predicates(self, domain: PDDLDomain, problem: PDDLProblem):
        """Inject initial predicates for ``problem`` (contract defined by subclasses)."""
        raise NotImplementedError()

    @abstractmethod
    def get_supported_predicates(self, domain: PDDLDomain) -> Optional[List[str]]:
        """Return the predicates this environment can evaluate.

        ``None`` makes ``validate`` compare the full expected initial state
        without filtering.
        """
        raise NotImplementedError()
