from dataclasses import dataclass
from typing import Dict, List, Optional, Union, Literal

from pddl.core import Constant, Predicate
from python_utils.string_utils import xml_escape

from tp_lodge.task_planning.models.pddl.pddl_object import PDDLObject
from tp_lodge.utils.pddl_domain_syntax import parse_predicate


@dataclass
class PDDLPredicate(PDDLObject):
    """A PDDL predicate definition bundled with a description and bookkeeping flags.

    Construction validates that every term of ``definition`` carries at least
    one type tag and that ``pred_type`` is one of the allowed values, and
    passes ``description`` through ``xml_escape``.
    """

    # Underlying pddl predicate; each of its terms must have at least one type tag.
    definition: Predicate
    # Free-text description; stored in xml-escaped form after __post_init__.
    description: str
    # Must be "state" or "other" (checked in __post_init__).
    pred_type: Literal["state", "other"]
    predefined: bool = False
    newly_generated: bool = False

    # Defaults to "root" when no parent operator id is given.
    parent_operator_id: str = "root"

    def __post_init__(self):
        """Validate the fields and xml-escape the description.

        Raises:
            AssertionError: if any term of the definition has no type tags, or
                if ``pred_type`` is not ``"state"`` or ``"other"``.
        """
        # Raise explicitly instead of using ``assert`` so that the checks are
        # still enforced when Python runs with -O (which strips asserts).
        if any(len(list(term.type_tags)) == 0 for term in self.definition.terms):
            raise AssertionError("Predicate %s has no type tags" % self.definition.name)
        self.description = xml_escape(self.description)
        allowed_pred_types = ["state", "other"]
        if self.pred_type not in allowed_pred_types:
            raise AssertionError(
                "Predicate %s has invalid pred_type: %s. Must be one of: %s"
                % (self.definition.name, self.pred_type, allowed_pred_types)
            )

    @property
    def is_visual(self) -> bool:
        """True when ``pred_type`` is ``"state"`` and the predicate is not predefined."""
        return self.pred_type == "state" and not self.predefined

    @property
    def name(self) -> str:
        """Name of the underlying predicate definition."""
        return self.definition.name

    def definition_str(self) -> str:
        """Render the definition as ``(name ?arg - type ...)``.

        Only the first type tag yielded by each term is used.
        """
        args = ["?%s - %s" % (term.name, next(iter(term.type_tags))) for term in self.definition.terms]
        return "(%s %s)" % (self.definition.name, " ".join(args))

    def __str__(self):
        return "%s: %s" % (self.definition_str(), self.description)

    def to_pddl(self) -> Predicate:
        """Return the underlying pddl ``Predicate``."""
        return self.definition

    @classmethod
    def from_json(cls, data: Dict) -> "PDDLPredicate":
        """Build a predicate from a dict.

        ``"definition"`` (parsed with ``parse_predicate``) and ``"description"``
        are required keys. Optional keys and their defaults:
        ``"newly_generated"`` (False), ``"predefined"`` (False),
        ``"pred_type"`` ("other"), ``"parent_operator_id"`` ("root").
        """
        # NOTE(review): the description is escaped again by __post_init__; if
        # ``data`` came from ``to_json`` it is already escaped, so a round trip
        # may double-escape -- confirm against xml_escape's behaviour.
        return cls(
            definition=parse_predicate(data["definition"]),
            description=data["description"],
            newly_generated=data.get("newly_generated", False),
            predefined=data.get("predefined", False),
            parent_operator_id=data.get("parent_operator_id", "root"),
            pred_type=data.get("pred_type", "other"),
        )

    def to_json(self) -> Dict:
        """Serialize to a dict; the definition is emitted via ``definition_str``."""
        return {
            "definition": self.definition_str(),
            "pred_type": self.pred_type,
            "description": self.description,
            "newly_generated": self.newly_generated,
            "predefined": self.predefined,
            "parent_operator_id": self.parent_operator_id,
        }

    def update_inplace(
        self,
        definition: Optional[Predicate] = None,
        description: Optional[str] = None,
        parent_operator_id: Optional[str] = None,
    ) -> None:
        """Overwrite the given fields on this instance; ``None`` leaves a field as is.

        NOTE(review): unlike ``__post_init__``, the description is stored as
        passed (not xml-escaped) and the definition is not re-validated --
        confirm that callers rely on this.
        """
        # Assign whenever a value is supplied. An inequality guard adds nothing
        # for equal values and would silently drop the update if ``Predicate``
        # equality is looser than term-by-term comparison.
        if definition is not None:
            self.definition = definition
        if description is not None:
            self.description = description
        if parent_operator_id is not None:
            self.parent_operator_id = parent_operator_id

    def copy_with(
        self,
        predefined: Optional[bool] = None,
        pred_type: Optional[str] = None,
        definition: Optional[Predicate] = None,
        description: Optional[str] = None,
        newly_generated: Optional[bool] = None,
        parent_operator_id: Optional[str] = None,
    ) -> "PDDLPredicate":
        """Return a new ``PDDLPredicate`` with the given fields replaced.

        Arguments left as ``None`` are taken from this instance. A newly
        supplied ``description`` is xml-escaped by the constructor; an
        inherited one is carried over unchanged.
        """
        clone = PDDLPredicate(
            pred_type=pred_type if pred_type is not None else self.pred_type,
            definition=definition if definition is not None else self.definition,
            description=description if description is not None else self.description,
            newly_generated=newly_generated if newly_generated is not None else self.newly_generated,
            predefined=predefined if predefined is not None else self.predefined,
            parent_operator_id=(
                parent_operator_id if parent_operator_id is not None else self.parent_operator_id
            ),
        )
        if description is None:
            # self.description was already escaped when this instance was
            # built; the constructor escaped it a second time, so restore the
            # stored value to keep repeated copies from compounding escapes.
            clone.description = self.description
        return clone

    def verify(self, predicate: Predicate, objects: List[Constant], type_hierarchy: Dict[str, List[str]]) -> None:
        """Verify that a predicate instance matches this predicate definition.

        Delegates to ``verify_predicate`` with this definition as ``l_predicate``.

        Args:
            predicate: The predicate instance to verify
            objects: List of available objects for type checking
            type_hierarchy: Dictionary mapping type names to their parent types (including themselves)
        """
        # Imported at call time rather than module level; presumably to avoid
        # a circular import -- TODO confirm.
        from tp_lodge.utils.pddl_verify_utils import verify_predicate

        verify_predicate(
            g_predicate=predicate, l_predicate=self.definition, objects=objects, type_hierarchy=type_hierarchy
        )
