import logging
from typing import List
from pddl.core import Predicate
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from demos.ipc.src.logistics.logistics_environment import LogisticsEnvironment
from demos.ipc.src.logistics.logistics_motion_simulator import LogisticsMotionSimulator
from tp_lodge.motion_planning.local_motion_validator import LocalMotionValidator


logger = logging.getLogger(__name__)


class LogisticsMotionValidator(LocalMotionValidator):
    """Check logistics PDDL predicates against a simulated logistics environment.

    The validator wraps a ``LogisticsMotionSimulator`` built around a fresh
    ``LogisticsEnvironment`` and answers predicate queries by reading the
    wrapped environment's current ``state`` (``self.env.env.state``).
    """

    # Names returned by ``get_supported_predicates`` (order preserved).
    # NOTE(review): "truck-in" and "plane-in" are advertised here but have no
    # entry in ``_PREDICATE_CHECKS``, so ``validate_predicate`` raises
    # ValueError for them. Either implement them or stop advertising them —
    # TODO confirm their intended semantics against the logistics domain file.
    _SUPPORTED_PREDICATE_NAMES = (
        "in-city",
        "at-truck",
        "truck-in",
        "in-truck",
        "at",
        "at-plane",
        "plane-in",
        "in-plane",
        "is-airport",
    )

    # (predicate name, arity) -> check(state, *argument_names) -> bool.
    # Lookups of unknown ids fall back to an empty dict / list, so a
    # reference to an object missing from the state yields False rather
    # than raising.
    _PREDICATE_CHECKS = {
        ("in-city", 2): lambda state, location, city: (
            location in state.cities.get(city, [])
        ),
        ("at-truck", 2): lambda state, truck_id, location: (
            state.trucks.get(truck_id, {}).get("location") == location
        ),
        ("in-truck", 2): lambda state, package_id, truck_id: (
            package_id in state.trucks.get(truck_id, {}).get("packages", [])
        ),
        ("at", 2): lambda state, package_id, location: (
            state.packages.get(package_id, {}).get("location") == location
        ),
        ("at-plane", 2): lambda state, plane_id, location: (
            state.airplanes.get(plane_id, {}).get("location") == location
        ),
        ("in-plane", 2): lambda state, package_id, plane_id: (
            package_id in state.airplanes.get(plane_id, {}).get("packages", [])
        ),
        ("is-airport", 1): lambda state, location: location in state.airports,
    }

    def __init__(self, hide_env_feedback: bool):
        """Build the simulator and initialise its state with seed 0.

        Args:
            hide_env_feedback: Forwarded unchanged to ``LocalMotionValidator``.
        """
        simulator = LogisticsMotionSimulator(env=LogisticsEnvironment())
        super().__init__(env=simulator, hide_env_feedback=hide_env_feedback)
        # Fixed seed; presumably makes the initial state reproducible —
        # confirm against LocalMotionValidator.init_state.
        self.init_state(seed=0)

    def get_supported_predicates(self, domain: PDDLDomain) -> List[str]:
        """Return a fresh list of the predicate names this validator reports.

        Args:
            domain: Not used by this implementation.

        Returns:
            A new list on every call, so callers may mutate it freely.
        """
        return list(self._SUPPORTED_PREDICATE_NAMES)

    def validate_predicate(self, pred: Predicate) -> bool:
        """Return whether ``pred`` holds in the simulator's current state.

        Args:
            pred: Predicate whose terms expose ``name``; those names are
                matched against the ids stored in the environment state.

        Returns:
            True if the predicate is satisfied by ``self.env.env.state``,
            False otherwise (including when a referenced id is absent).

        Raises:
            ValueError: If there is no check for this name/arity pair.
        """
        arg_names = [term.name for term in pred.terms]
        current_state = self.env.env.state

        check = self._PREDICATE_CHECKS.get((pred.name, len(arg_names)))
        if check is None:
            raise ValueError(
                f"Predicate '{pred.name}' with arguments {arg_names} is not supported by the logistics motion validator."
            )
        return check(current_state, *arg_names)
