import logging
import shutil
from pathlib import Path
from typing import List

from llm_utils import TextGenApi
from llm_utils.openai_api.chat import Chat

from tp_lodge.motion_planning.motion_validator import MotionValidator
from tp_lodge.motion_planning.motion_validator import (
    MotionSimulationException,
)
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.task_planning.models.planning.plan_result import (
    PlanCompositeActionResult,
    PlanResult,
)
from tp_lodge.task_planning.models.sas.sas_plan import SasPlan
from tp_lodge.task_planning.pddl_planner.ai_state_transition_retriever import AIStateTransitionRetriever
from tp_lodge.task_planning.pddl_planner.hi_planner.shared_action_node_storage import SharedActionNodeStorage

logger = logging.getLogger(__name__)


class PlanTraversal:
    """Traverse a SAS plan action by action, delegating each action to a child node.

    For each action of a plan at the current hierarchy level, a
    ``PDDLActionNodeWrapper`` is created and inferred. Results are recorded in a
    ``PlanCompositeActionResult`` that is injected into the parent ``PlanResult``.
    Output of a previous run (stored under ``<out_dir>/sub-actions``) is compared
    with the current plan and domain so unchanged actions can reuse it.
    """

    def __init__(
        self,
        motion_validator: MotionValidator,
        textgen_api: TextGenApi,
        storage: SharedActionNodeStorage,
    ):
        """Store the collaborators handed down to every child action node.

        :param motion_validator: validator passed to each child node.
        :param textgen_api: text-generation API passed to each child node.
        :param storage: shared action-node storage passed to each child node.
        """
        self.motion_validator = motion_validator
        self.textgen_api = textgen_api
        self.storage = storage

    def traverse(
        self,
        domain: PDDLDomain,
        problem: PDDLProblem,
        parent_operators: List[str],
        sas_plan: SasPlan,
        env_hash: str,
        out_dir: Path,
        replan: bool,
        parent_chat: Chat,
        plan_result: PlanResult,
    ):
        """Infer every action of ``sas_plan`` in order and collect the results.

        :param domain: full PDDL domain; reduced to the current level via ``remove_for_level``.
        :param problem: problem providing the initial state and the objects.
        :param parent_operators: ids of the operators above this level in the hierarchy.
        :param sas_plan: plan whose actions are traversed in order.
        :param env_hash: environment hash handed to the first child; each child returns
            the hash used for the next one.
        :param out_dir: directory under which the ``sub-actions`` output is stored.
        :param replan: currently not read (the replan check runs whenever previous output
            exists); kept for interface compatibility.
        :param parent_chat: chat handed to each child node.
        :param plan_result: overall result into which a composite result is injected.
        :return: tuple ``(env_hash, subplan_result)``, with the env hash after the last action.
        :raises MotionSimulationException: re-raised with the failing sub-action in the message.
        """
        subplan_result = PlanCompositeActionResult()
        plan_result.inject_composite(composite=subplan_result)

        sub_actions_out = out_dir / "sub-actions"
        prev_domain_file = sub_actions_out / "prev-domain.json"
        prev_sas_plan_file = sub_actions_out / "prev-plan"

        # One replan flag per action; output from a previous run may switch some on.
        if sub_actions_out.is_dir():
            replans = self._detect_replans(
                domain=domain,
                parent_operators=parent_operators,
                sas_plan=sas_plan,
                sub_actions_out=sub_actions_out,
                prev_domain_file=prev_domain_file,
                prev_sas_plan_file=prev_sas_plan_file,
            )
        else:
            replans = [False for _ in sas_plan.actions]

        # Persist the current domain and plan so the next run can diff against them.
        sub_actions_out.mkdir(exist_ok=True)
        prev_domain_file.write_text(domain.dumps())
        prev_sas_plan_file.write_text(sas_plan.to_string())

        curr_state = problem.initial_state
        assert curr_state is not None

        current_level_domain = domain.remove_for_level(parent_op_ids=parent_operators)
        assert current_level_domain.has_unique_names()

        # TODO: better way?! Imported here (not at module level) to avoid a circular import.
        from tp_lodge.task_planning.pddl_planner.hi_planner.pddl_action_node_wrapper import (
            PDDLActionNodeWrapper,
        )

        for idx, (sas_action, replan_action) in enumerate(zip(sas_plan.actions, replans)):
            action_subdir = sub_actions_out / ("%d-%s" % (idx, sas_action.name))
            logger.info("-" * 100)
            logger.info("Plan next action in hierarchy: %s", sas_action.to_string())
            subplan_result.add_action(sas_action=sas_action)

            operator = current_level_domain.get_operator(pddl_name=sas_action.name)

            child_node = PDDLActionNodeWrapper(
                operator_id=operator.id,
                parent_operators=parent_operators,
                storage=self.storage,
                motion_validator=self.motion_validator,
                textgen_api=self.textgen_api,
                out_dir=action_subdir,
                sas_action=sas_action,
                domain=domain,
                problem=problem,
                plan_result=plan_result,
                parent_chat=parent_chat,
                subplan_result=subplan_result,
                curr_state=curr_state,
                env_hash=env_hash,
            )

            try:
                action_result, env_hash = child_node.infer(replan=replan_action)
            except MotionSimulationException as e:
                raise e.copy_with(
                    message="Error in sub-action %s: %s" % (sas_action.to_string(), e.message),
                )
            subplan_result.specify_action(sas_action=sas_action, result=action_result)

            # Advance to the state after this action. TODO: also computed in `infer`
            state_problem = AIStateTransitionRetriever().retrieve_single(
                domain=current_level_domain,
                objects=problem.objects,
                current_state=curr_state,
                action=sas_action.to_string(),
            )
            curr_state = state_problem.goal_state_list

        return env_hash, subplan_result

    def _detect_replans(
        self,
        domain: PDDLDomain,
        parent_operators: List[str],
        sas_plan: SasPlan,
        sub_actions_out: Path,
        prev_domain_file: Path,
        prev_sas_plan_file: Path,
    ) -> List[bool]:
        """Diff the current plan/domain against the previous run and prepare ``sub_actions_out``.

        Side effects on ``sub_actions_out``:
        - ``*.hash`` files of the first differing action and of every later action are
          deleted, so their motion validation reruns.
        - When the action name at a position changed, the old directory is deleted.
          If the new action appeared elsewhere in the previous plan, that action's
          output is copied into its new directory.

        :return: one flag per action of ``sas_plan``; ``True`` means the child must replan.
        """
        replans = [False for _ in sas_plan.actions]

        logger.debug("Check what to replan in action sequence")
        # The domain file is written together with the directory on every run.
        assert prev_domain_file.is_file()

        prev_domain = PDDLDomain.loads(prev_domain_file.read_text())  # type: PDDLDomain
        prev_plan = SasPlan.from_string(prev_sas_plan_file.read_text())

        # Snapshot the previous output so moved actions can still be copied from it
        # after their original directories were deleted below.
        sub_actions_out_cache = sub_actions_out.parent / ("%s.cache" % sub_actions_out.name)
        if sub_actions_out_cache.is_dir():
            shutil.rmtree(sub_actions_out_cache)
        shutil.copytree(sub_actions_out, sub_actions_out_cache)

        try:
            # Remove env hashes so copied outputs are always re-validated.
            for hash_file in sub_actions_out_cache.rglob("*.hash"):
                hash_file.unlink()

            domain_for_level = domain.remove_for_level(parent_op_ids=parent_operators)
            # Must be derived from the *previous* domain; deriving it from `domain`
            # made both definitions identical and hid every operator change.
            prev_domain_for_level = prev_domain.remove_for_level(parent_op_ids=parent_operators)

            is_still_equal = True
            for idx, (sas_action, prev_action) in enumerate(zip(sas_plan.actions, prev_plan.actions)):
                new_action_dir = sub_actions_out / ("%d-%s" % (idx, sas_action.name))
                old_action_dir = sub_actions_out / ("%d-%s" % (idx, prev_action.name))

                if sas_action != prev_action or not is_still_equal:
                    # Once an action differs, motion validation must rerun for it and all future ones.
                    is_still_equal = False
                    for hash_file in new_action_dir.rglob("*.hash"):
                        hash_file.unlink()
                    for hash_file in old_action_dir.rglob("*.hash"):
                        hash_file.unlink()

                if sas_action.name != prev_action.name:
                    # Different action name -> delete the stale output at this position.
                    is_still_equal = False
                    if old_action_dir.is_dir():
                        logger.info("Action %s changed. Delete this and future", prev_action)
                        shutil.rmtree(old_action_dir)

                    # If the action existed at another position, reuse its cached output.
                    if sas_action not in prev_plan.actions:
                        continue
                    old_idx = prev_plan.actions.index(sas_action)
                    cached_action_dir = sub_actions_out_cache / ("%d-%s" % (old_idx, sas_action.name))
                    if not cached_action_dir.is_dir():
                        continue
                    logger.info("Action %s changed. Copy from previous plan", sas_action)
                    if new_action_dir.is_dir():
                        shutil.rmtree(new_action_dir)
                    shutil.copytree(cached_action_dir, new_action_dir)

                prev_action_def = prev_domain_for_level.get_operator(pddl_name=sas_action.name)
                curr_action_def = domain_for_level.get_operator(pddl_name=sas_action.name)
                if prev_action_def == curr_action_def:
                    continue

                # Different action definition, e.g. effects or preconditions. Composite actions
                # (those with their own "sub-actions" output) are not replanned, even if the
                # params change.
                if (
                    prev_action_def.param_names() != curr_action_def.param_names()
                    and not (new_action_dir / "sub-actions").is_dir()
                ):
                    logger.info("Action %s changed. Set replan flag for this and future", sas_action.name)
                    replans[idx] = True
                else:
                    # TODO: this is not 100% correct, since the old definition is still in the chat
                    # history, but it's required since goal overshoot would trigger replanning from
                    # scratch, risking a different decomposition of the action.
                    logger.info("Action %s changed. Same # of parameters, so no replan", sas_action.name)
                # TODO: setting replans[idx:] would be correct if the state for subsequent actions
                # also changed.
        finally:
            # Always drop the snapshot, also when the comparison fails midway.
            shutil.rmtree(sub_actions_out_cache)

        return replans
