# Copyright (c) 2024-present, Royal Bank of Canada.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# copied from https://github.com/BorealisAI/llm-pddl-planning/blob/main/src/domains.py

from typing import Callable, Dict, List, Optional

import fast_downward
import numpy as np
from fast_downward import Atom, Operator
from pddl.core import And
from pddl.parser.problem import ProblemParser

from tp_lodge.utils.planning_cache_utils import SasAction


def get_problem_pddl_empty_goal(problem_pddl: str):
    """Return the given PDDL problem text re-serialised with an empty goal.

    The text is parsed with ``ProblemParser``, the parsed problem's goal is
    overwritten with an argument-less ``And()``, and the result of ``str()``
    on the modified problem is returned.
    """
    parse = ProblemParser()
    goal_free_problem = parse(problem_pddl)
    # The goal is replaced by writing straight to the private ``_goal``
    # attribute of the parsed problem object.
    goal_free_problem._goal = And()
    return str(goal_free_problem)


class RandomWalkEnv:
    """Random-walk sampling and plan replay on top of the ``fast_downward`` bindings.

    On construction the problem's goal is replaced by an empty conjunction,
    the domain/problem pair is translated to SAS, and the encoded SAS text is
    kept so it can be (re)loaded into the library before every walk or replay.

    Action names exposed by this class are the strings produced by
    ``function_mapping``; they are expected to look like ``type(arg1, arg2)``.
    """

    def __init__(
        self, domain_pddl: str, problem_pddl: str, function_mapping: Dict[str, Callable[[SasAction], str]]
    ) -> None:
        """Load the planner library and translate the task to SAS.

        Args:
            domain_pddl: PDDL domain text.
            problem_pddl: PDDL problem text; its goal is emptied before translation.
            function_mapping: Maps a ``SasAction`` name to a callable that
                renders the ``SasAction`` as the action-name string used as
                key by this class.
        """
        self.lib = fast_downward.load_lib()
        self.function_mapping = function_mapping

        problem_pddl = get_problem_pddl_empty_goal(problem_pddl)
        _task, sas = fast_downward.pddl2sas(domain_pddl, problem_pddl)
        self.sas = sas.encode("utf-8")

    def __del__(self):
        self.close()

    def close(self):
        """Release resources held by the environment (currently a no-op).

        NOTE: unloading the library here (``close_lib(self.lib)`` followed by
        ``del self.lib``) was tried and produced a segmentation fault, so the
        handle is intentionally left alone.
        """
        pass

    def _reset_lib(self):
        """Load the stored SAS task into the library.

        Called at the start of every walk/replay; presumably this puts the
        library back at the task's initial state (confirm against the
        ``fast_downward`` bindings).
        """
        self.lib.load_sas(self.sas)

    @staticmethod
    def _action_type(action_name: str) -> str:
        """Return the part of ``action_name`` before its first ``(``."""
        return action_name.split("(")[0]

    @staticmethod
    def _action_arguments(action_name: str) -> List[str]:
        """Return the stripped, comma-separated arguments of ``type(a, b, ...)``.

        A name without parentheses is treated as having a single empty argument.
        """
        pieces = action_name.split("(")
        inner = pieces[1].split(")")[0] if len(pieces) > 1 else ""
        return [argument.strip() for argument in inner.split(",")]

    @classmethod
    def _has_repeated_arguments(cls, action_name: str) -> bool:
        """Tell whether the same argument occurs more than once in ``action_name``."""
        arguments = cls._action_arguments(action_name)
        return len(set(arguments)) != len(arguments)

    def get_random_walk_plan(self, max_steps: int, *, choose_uniformly_over_type: bool = True):
        """Sample a random sequence of applicable actions from the loaded task.

        Args:
            max_steps: Upper bound on the number of actions in the walk.
            choose_uniformly_over_type: If true, first draw an action type
                (the text before ``(``) uniformly, then draw uniformly among
                the usable actions of exactly that type. Otherwise draw
                uniformly among all usable actions.

        Returns:
            The list of chosen action names. It is shorter than ``max_steps``
            when a state is reached in which no usable action exists.
        """
        self._reset_lib()
        # The generator is seeded from NumPy's global RNG, so ``np.random.seed``
        # controls the walk.
        seed = np.random.randint(2**32 - 1)
        rng = np.random.default_rng(seed)

        plan = []
        for _ in range(max_steps):
            available_actions = self._get_applicable_actions()

            # HACK: skip actions that use the same argument more than once.
            # fast_downward behaves non-deterministically when applying such
            # actions (sometimes ok, sometimes not allowed), e.g. in logistics
            #   drive-a-truck-from-one-location-to-another-in-a-city
            #       (?trk - truck ?from - location ?to - location ?city - city)
            # with ?from == ?to.
            usable_names = [name for name in available_actions if not self._has_repeated_arguments(name)]
            if not usable_names:
                if available_actions:
                    # Every applicable action was rejected by the filter above;
                    # report the rejected names and end the walk instead of
                    # sampling from an empty list (which raises ValueError).
                    print(f"Warning: no matching action names found. {list(available_actions)}")
                break

            if choose_uniformly_over_type:
                # Group by exact type in first-seen order: exact matching keeps
                # "move" from also selecting "move-fast(...)", and a stable
                # order keeps seeded walks reproducible across processes.
                names_by_type: Dict[str, List[str]] = {}
                for name in usable_names:
                    names_by_type.setdefault(self._action_type(name), []).append(name)
                action_type = str(rng.choice(list(names_by_type)))
                action_name = str(rng.choice(names_by_type[action_type]))
            else:
                action_name = str(rng.choice(usable_names))

            plan.append(action_name)
            self._apply_action(available_actions[action_name])
        return plan

    def get_plan(self, plan: List[str]) -> Optional[List[str]]:
        """Replay ``plan`` step by step from a freshly loaded task.

        Args:
            plan: Action names in the format produced by ``function_mapping``.

        Returns:
            ``None`` as soon as a step is not among the currently applicable
            actions; otherwise one ``"(<operator name>)"`` string per step,
            built from the ``name`` of the applied operator.
        """
        self._reset_lib()
        plan_so_far = []
        for action_name in plan:
            available_actions = self._get_applicable_actions()
            if action_name not in available_actions:
                return None
            action = available_actions[action_name]
            self._apply_action(action)
            plan_so_far.append(f"({action.name})")

        return plan_so_far

    def get_plan_execution_feedback(self, plan: List[str]):
        """Return ``True`` iff every step of ``plan`` could be applied in order."""
        return self.get_plan(plan) is not None

    def _get_applicable_actions(self) -> dict:
        """Return the currently applicable operators keyed by mapped action name.

        The library is asked for the operator count, a ctypes array of that
        size is filled, each operator is parsed into a ``SasAction`` from
        ``"(<op.name>)"``, and the ``SasAction`` is rendered to its key via
        ``function_mapping[<SasAction name>]``.
        """
        operator_count = self.lib.get_applicable_operators_count()
        operator_buffer = (Operator * operator_count)()
        self.lib.get_applicable_operators(operator_buffer)
        by_sas_action = {SasAction.from_string(f"({op.name})"): op for op in operator_buffer}
        return {self.function_mapping[sas_action.name](sas_action): op for sas_action, op in by_sas_action.items()}

    def _apply_action(self, action):
        """Apply ``action`` through the library and return the filled effect-atom array."""
        effects = (Atom * action.nb_effect_atoms)()
        self.lib.apply_operator(action.id, effects)
        return effects
