import logging
from pathlib import Path
from typing import List, Tuple, TypeVar

from llm_utils.textgen_api.textgen_api import TextGenApi
from pddl.core import And

from tp_lodge.motion_planning.local_motion_validator import PDDLDomain, PDDLProblem
from tp_lodge.motion_planning.motion_validator import (
    MotionSimulationException,
    MotionSimulationResponseCode,
    MotionValidator,
    get_aligned_effects,
)
from tp_lodge.task_planning.models.planning.plan_result import (
    PlanActionResult,
)
from tp_lodge.task_planning.models.sas.sas_plan import SasPlan
from tp_lodge.task_planning.pddl_generators.pddl_llm_interface import PDDLLLMInterface
from tp_lodge.task_planning.pddl_planner.ai_planner.ai_validator import AIValidator
from tp_lodge.task_planning.pddl_planner.ai_planner.fastdownward_ai_planner import FastDownwardAIPlanner
from tp_lodge.task_planning.pddl_planner.ai_state_transition_retriever import AIStateTransitionRetriever
from tp_lodge.task_planning.pddl_planner.hi_planner.mixins.planner_mixin import PlannerMixin
from tp_lodge.task_planning.pddl_planner.hi_planner.shared_action_node_storage import SharedActionNodeStorage
from tp_lodge.utils.path_utils import get_prompts_dir
from tp_lodge.utils.pddl_utils import filter_formula_by_predicates, get_matching_predicates, is_predicates_subset

# Type variable bound to ActionNode; the bound is given as a string because the
# class is defined further down in this module.
# NOTE(review): `T` is not referenced in the code visible here — confirm it is
# still needed before removing it.
T = TypeVar("T", bound="ActionNode")

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class ActionNode(PlannerMixin):
    """Planner node identified by an operator id.

    On construction it builds a ``FastDownwardAIPlanner`` (configured from
    ``storage.ai_planner_kwargs``), an ``AIValidator`` and a ``PDDLLLMInterface``
    (from the prompts directory and ``textgen_api``), and hands them, together with
    the motion validator and the shared storage, to ``PlannerMixin``.
    """

    def __init__(
        self,
        operator_id: str,
        parent_operators: List[str],
        storage: SharedActionNodeStorage,
        motion_validator: MotionValidator,
        textgen_api: TextGenApi,
        out_dir: Path,
    ):
        """Create the node and initialise the planner mixin.

        Args:
            operator_id: identifier of this node's operator (stored as-is).
            parent_operators: identifiers of the parent operators (stored as-is).
            storage: shared storage; its ``ai_planner_kwargs`` configure the Fast
                Downward planner and it is forwarded to ``PlannerMixin``.
            motion_validator: forwarded to ``PlannerMixin``; used by
                ``check_did_fulfill_goal``.
            textgen_api: text-generation client used by the LLM interface.
            out_dir: output directory; ``delete_env_hashes`` removes ``*.hash``
                files below it.

        Raises:
            AssertionError: if the prompts directory is not an existing directory
                (only when assertions are enabled).
        """
        self.operator_id = operator_id
        self.parent_operators = parent_operators
        self.textgen_api = textgen_api
        self.out_dir = out_dir
        # NOTE(review): presumably read by PlannerMixin to toggle interactive mode — confirm.
        self.interactive = False

        prompts_dir = get_prompts_dir()
        assert prompts_dir.is_dir(), f"prompts directory not found: {prompts_dir}"

        ai_planner = FastDownwardAIPlanner(**storage.ai_planner_kwargs)
        ai_validator = AIValidator()
        llm_interface = PDDLLLMInterface(prompts_dir=prompts_dir, textgen_api=textgen_api)
        super().__init__(
            motion_validator=motion_validator,
            llm_interface=llm_interface,
            ai_validator=ai_validator,
            ai_planner=ai_planner,
            storage=storage,
        )

    def delete_env_hashes(self):
        """Delete every ``*.hash`` file anywhere below ``self.out_dir``."""
        # Materialise the matches first so the directory tree is not modified
        # while the recursive glob is still walking it.
        for hash_file in list(self.out_dir.rglob("*.hash")):
            hash_file.unlink()

    def check_did_fulfill_goal(
        self,
        problem: PDDLProblem,
        domain: PDDLDomain,
        sas_plan: SasPlan,
        env_hash: str,
        *,
        subset_check: bool = True,
    ) -> None:
        """Check whether executing ``sas_plan`` reaches the goal state of ``problem``.

        The reached ("ground truth") state is obtained from
        ``AIStateTransitionRetriever`` when the plan has actions. When the plan is
        empty, it is read from the motion validator's current predicate evaluation.

        The method returns nothing: returning normally means the check passed,
        failure is signalled by raising.

        Args:
            problem: problem whose ``goal_state`` (and, when ``subset_check`` is
                False, ``initial_state``) is compared against the reached state.
            domain: PDDL domain of the problem and plan; must have unique names.
            sas_plan: plan whose induced state transitions are evaluated.
            env_hash: expected environment hash; must equal
                ``self.motion_validator.get_env_hash()``.
            subset_check: if True (default), the goal predicates supported by the
                motion validator must be a subset of the reached state. If False,
                the expected effects (initial -> goal) must equal the ground-truth
                effects (initial -> reached) as aligned by ``get_aligned_effects``.

        Raises:
            AssertionError: on env-hash mismatch or non-unique domain names (only
                when assertions are enabled).
            MotionSimulationException: with code ``UNMET_GOAL`` if ``subset_check``
                is True and the goal is not contained in the reached state.
            RuntimeError: if ``subset_check`` is False and the aligned effects differ.
        """
        assert self.motion_validator.get_env_hash() == env_hash, "env_hash mismatch"
        assert domain.has_unique_names(), "domain must have unique names"

        supported_preds = self.motion_validator.get_supported_predicates(domain=domain)
        if not sas_plan.actions:
            # No actions to simulate: the reached state is simply the current
            # evaluation of the predicates that exist in the simulation.
            gt_goal_state = self.motion_validator.get_predicates_evaluation(domain=domain, problem=problem)
        else:
            gt_goal_state = AIStateTransitionRetriever().retrieve_goal_state(
                domain=domain.to_pddl(), problem=problem.to_pddl(), actions=sas_plan.to_string()
            )

        if subset_check:
            # Drop goal predicates the simulation does not support, since the
            # reached state may not contain them.
            exp_goal_state = filter_formula_by_predicates(problem.goal_state, supported_preds)
            if not is_predicates_subset(exp_goal_state, gt_goal_state):
                # Restrict the reached state to the predicates named in the goal so
                # the report only shows the relevant part.
                aligned_gt_post = And(*get_matching_predicates(pred_a=problem.goal_state, pred_b=gt_goal_state))
                logger.info(
                    "Expected effects do not match ground truth effects\n\nExp: %s\n\nGT: %s",
                    problem.goal_state,
                    aligned_gt_post,
                )
                raise MotionSimulationException(
                    expected=str(problem.goal_state),
                    ground_truth=str(aligned_gt_post),
                    message="The plan did not fulfill the goal state.",
                    code=MotionSimulationResponseCode.UNMET_GOAL,
                )
        else:
            exp_effects, gt_effects = get_aligned_effects(
                gt_prior=problem.initial_state,
                gt_post=gt_goal_state,
                exp_prior=problem.initial_state,
                exp_post=problem.goal_state,
                supported_predicates=supported_preds,
            )

            if exp_effects != gt_effects:
                raise RuntimeError(
                    f"Expected effects do not match ground truth effects\n\nExp: {exp_effects}\n\nGT: {gt_effects}"
                )
