from typing import Optional
from pddl.core import Problem

from tp_lodge.task_planning.models.sas.sas_plan import SasPlan
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.pddl_planner.ai_state_transition_retriever import AIStateTransitionRetriever
from tp_lodge.utils.pddl_utils import filter_formula_by_predicates, combine_predicates, get_predicate_evaluation, get_list_of_predicates
from tp_lodge.utils.pddl_lib_utils import copy_problem_w_args

from .fastdownward_ai_planner import FastDownwardAIPlanner


class HierarchicalAIPlanner:
    """Plan over a layered PDDL domain, refining abstract actions level by level.

    Planning starts at the ``"root"`` level of the domain. Each action in the
    resulting plan whose operator maps to more than one skill is decomposed by
    solving a sub-problem at that operator's child level. The final result is
    a flat ``SasPlan`` containing only the actions that were not decomposed.
    """

    def __init__(self) -> None:
        # Classical planner invoked once per level (and once per decomposed action).
        self.ai_planner = FastDownwardAIPlanner(alias="lama-first", search_time_limit=30)

    def plan(self, domain: PDDLDomain, problem: Problem) -> Optional[SasPlan]:
        """Return a flattened plan for ``problem``, or ``None`` if planning fails.

        Planning begins at the ``"root"`` level of ``domain``.
        """
        return self._plan_level(domain=domain, problem=problem, parent_op_ids=["root"])

    def _plan_level(self, domain: PDDLDomain, problem: Problem, parent_op_ids: list[str]) -> Optional[SasPlan]:
        """Plan ``problem`` at the level selected by ``parent_op_ids`` and flatten the result.

        Args:
            domain: Full hierarchical domain; it is restricted to a single level
                via ``domain.remove_for_level``.
            problem: Problem to solve; ``problem.init`` is the starting state.
            parent_op_ids: Operator ids identifying the level, presumably ordered
                root-first (``plan`` passes ``["root"]`` and decomposition appends
                child operator ids).

        Returns:
            A ``SasPlan`` of non-decomposed actions. Returns ``None`` in any of
            these cases:
            - the planner fails at this level or any sub-level;
            - the operator owning the goal predicates is not reachable from ``"root"``;
            - the goal predicates are not available even at their owner's level.

        Raises:
            NotImplementedError: If the goal predicates missing from this level
                belong to more than one operator, or if a planned action's
                operator has no mapped skill sequence.
            AssertionError: If the restricted domain fails ``has_unique_names``
                (only when assertions are enabled).
        """
        domain_level = domain.remove_for_level(parent_op_ids=parent_op_ids)
        # Predicate names visible at this level; invariant for the whole call.
        level_predicate_names = [p.name for p in domain_level.predicates]

        # Goal facts over predicates that this level does not know about.
        unsupported_goal_predicates = get_list_of_predicates(filter_formula_by_predicates(
            problem.goal, level_predicate_names, inverse=True
        ))
        if unsupported_goal_predicates:
            # Re-plan at the level of the operator that owns those predicates.
            predicates = [
                domain.get_predicate(p.name) for p in get_predicate_evaluation(unsupported_goal_predicates).keys()
            ]
            owner_op_ids = {p.parent_operator_id for p in predicates}
            if len(owner_op_ids) != 1:
                raise NotImplementedError(
                    f"goal predicates owned by multiple operators are not supported: {owner_op_ids}"
                )
            (owner_op_id,) = owner_op_ids
            op_hierarchy = list(reversed(domain.get_parent_operator_ids(owner_op_id)))
            if 'root' not in op_hierarchy:
                # operator probably doesn't exist anymore
                return None
            if op_hierarchy == list(parent_op_ids):
                # Recursing with identical arguments would take this same branch
                # again and never terminate: the goal cannot be expressed here.
                return None
            return self._plan_level(domain=domain, problem=problem, parent_op_ids=op_hierarchy)

        assert domain_level.has_unique_names(), "domain level contains non-unique names"
        _, success, sas_plan = self.ai_planner.plan(
            domain=domain_level.to_pddl(domain_name=problem.domain_name), problem=problem
        )
        if not success:
            return None

        flattened_plan = []
        curr_state = problem.init

        for sas_action in sas_plan.actions:
            action_op = domain_level.get_operator(sas_action.name)

            # Only facts over this level's predicates are given to the transition
            # model; facts from other levels are carried over unchanged.
            curr_state_known = filter_formula_by_predicates(curr_state, level_predicate_names)
            curr_state_unknown = filter_formula_by_predicates(curr_state, level_predicate_names, inverse=True)
            next_state = (
                AIStateTransitionRetriever()
                .retrieve_single(
                    domain=domain_level,
                    objects=problem.objects,
                    current_state=curr_state_known,
                    action=sas_action.to_string(),
                )
                .goal_state_list
            )
            next_state = combine_predicates(list(next_state), curr_state_unknown)

            if action_op.mapped_skill_sequence is None:
                raise NotImplementedError(f"operator '{sas_action.name}' has no mapped skill sequence")
            if len(action_op.mapped_skill_sequence) == 1:
                # Primitive action: keep it as-is.
                # TODO: emit the parameterized mapped skill instead of the raw action.
                flattened_plan.append(sas_action)
            else:
                # NOTE(review): an empty skill sequence also lands here and is
                # decomposed -- confirm that is intended.
                # Decompose: the sub-problem's goal is the action's predicted effects.
                sub_problem = (
                    AIStateTransitionRetriever()
                    .retrieve_single(
                        domain=domain_level,
                        objects=problem.objects,
                        current_state=curr_state_known,
                        action=sas_action.to_string(),
                        effects_for_goal=True,
                    )
                    .to_pddl()
                )
                # we want the entire init state, including facts from other levels
                sub_problem = copy_problem_w_args(sub_problem, init=curr_state)

                flattened_sub_plan = self._plan_level(
                    domain=domain, problem=sub_problem, parent_op_ids=parent_op_ids + [action_op.id]
                )
                if flattened_sub_plan is None:
                    return None

                flattened_plan.extend(flattened_sub_plan.actions)

            # Advance using the predicted state, not the sub-plan's actual outcome.
            curr_state = next_state

        return SasPlan(actions=flattened_plan)
