from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from pddl.core import Constant, Domain, Predicate, Requirements

from tp_lodge.task_planning.models.pddl.pddl_object import PDDLObject
from tp_lodge.task_planning.models.pddl.pddl_operation import PDDLOperation
from tp_lodge.task_planning.models.pddl.pddl_operator import PDDLOperator
from tp_lodge.task_planning.models.pddl.pddl_predicate import PDDLPredicate
from tp_lodge.utils.pddl_utils import get_predicates_used_in_formula


@dataclass
class PDDLDomain(PDDLObject):
    """A PDDL domain made of operators, predicates and a type hierarchy.

    Operators and predicates carry a ``parent_operator_id`` that places them in an
    operator hierarchy keyed by operator ids; ``print_operator_hierarchy`` uses
    ``"root"`` as the top-level parent id.

    Attributes:
        operators: Operators (actions) of the domain.
        predicates: Predicates of the domain.
        types: Mapping from each type name to its direct parent type name. A parent
            that is not itself a key of this mapping terminates the hierarchy.
    """

    operators: List[PDDLOperator]
    predicates: List[PDDLPredicate]
    types: Dict[str, str]

    def copy_with(
        self,
        operators: Optional[List[PDDLOperator]] = None,
        predicates: Optional[List[PDDLPredicate]] = None,
        types: Optional[Dict[str, str]] = None,
    ) -> "PDDLDomain":
        """Return a new domain, replacing only the fields that are not ``None``.

        Fields that are not replaced are shared (not copied) with this domain.
        """
        return PDDLDomain(
            operators=operators if operators is not None else self.operators,
            predicates=predicates if predicates is not None else self.predicates,
            types=types if types is not None else self.types,
        )

    @classmethod
    def from_json(cls, data: Dict) -> "PDDLDomain":
        """Build a domain from its JSON representation.

        Operators are read from ``"actions"`` when that key is present, otherwise
        from ``"operators"`` (the key written by ``to_json``). Missing keys default
        to empty collections.
        """
        types = data.get("types", {})
        predicates = [PDDLPredicate.from_json(p) for p in data.get("predicates", [])]
        raw_operators = data["actions"] if "actions" in data else data.get("operators", [])
        operators = [PDDLOperator.from_json(a) for a in raw_operators]
        return cls(operators=operators, predicates=predicates, types=types)

    def to_json(self) -> Dict:
        """Serialize the domain into a JSON-compatible dict (inverse of ``from_json``)."""
        return {
            "operators": [operator.to_json() for operator in self.operators],
            "predicates": [predicate.to_json() for predicate in self.predicates],
            "types": self.types,
        }

    def child_types(self, type: str) -> List[str]:
        """Return every type that has ``type`` among its ancestors, including ``type`` itself."""
        assert type in self.types
        return [t for t in self.types if type in self.parent_types(t)]

    def parent_types(self, type: str) -> List[str]:
        """Return ``type`` followed by its ancestors, nearest first.

        The walk stops at the first parent that is not a key of ``self.types``;
        that parent is not included.
        """
        assert type in self.types

        all_types = [type]
        p_type = self.types[type]
        if p_type in self.types:
            all_types.extend(self.parent_types(p_type))
        return all_types

    def get_types_by_parents(self) -> Dict[str, List[str]]:
        """Group type names by their direct parent type (returned as a ``defaultdict(list)``)."""
        all_types = defaultdict(list)
        for type, p_type in self.types.items():
            all_types[p_type].append(type)
        return all_types

    def get_type_hierarchy(self) -> Dict[str, List[str]]:
        """Get a dictionary mapping each type to its parent types (including itself)."""
        return {type_name: self.parent_types(type_name) for type_name in self.types}

    def has_unique_names(self) -> bool:
        """Return whether operator names and predicate names are each free of duplicates."""
        op_names = [op.definition.name for op in self.operators]
        if len(op_names) != len(set(op_names)):
            return False

        pred_names = [p.definition.name for p in self.predicates]
        return len(pred_names) == len(set(pred_names))

    def mark_not_newly_generated(self):
        """Reset the "recently changed" markers on all operators and predicates in place."""
        for op in self.operators:
            op.last_operation = PDDLOperation.NONE
        for p in self.predicates:
            p.newly_generated = False

    def verify_predicate(self, predicate: Predicate, objects: List[Constant]):
        """Check a grounded predicate against its definition in this domain.

        The lookup uses exact (case-sensitive) name equality. The check itself is
        delegated to ``PDDLPredicate.verify`` together with the type hierarchy.

        Raises:
            ValueError: If no predicate with ``predicate.name`` exists in the domain.
        """
        matching_predicate_def = next((p for p in self.predicates if p.name == predicate.name), None)
        if matching_predicate_def is None:
            raise ValueError(f"Predicate {predicate.name} not found in domain.")

        matching_predicate_def.verify(predicate, objects, self.get_type_hierarchy())

    @property
    def fully_defined(self) -> bool:
        """Whether the domain contains at least one operator."""
        return len(self.operators) > 0

    def get_operator_by_id(self, op_id: str) -> PDDLOperator:
        """Return the first operator whose ``id`` equals ``op_id``.

        Raises:
            KeyError: If no operator has that id.
        """
        for operator in self.operators:
            if operator.id == op_id:
                return operator
        raise KeyError(f"Operator with id {op_id} not found in domain.")

    def get_operator(self, pddl_name: str) -> PDDLOperator:
        """Return the operator whose PDDL name matches ``pddl_name`` case-insensitively.

        Raises:
            KeyError: If no operator has that name.
        """
        # PDDL names are case-insensitive: normalize both sides, not only the stored name.
        target = pddl_name.lower()
        operators = [op for op in self.operators if op.definition.name.lower() == target]
        assert len(operators) <= 1, f"Multiple operators with name {pddl_name} found in domain."
        if operators:
            return operators[0]
        raise KeyError(f"Operator {pddl_name} not found in domain.")

    def get_predicate(self, pddl_name: str) -> PDDLPredicate:
        """Return the first predicate whose name matches ``pddl_name`` case-insensitively.

        Raises:
            KeyError: If no predicate has that name.
        """
        target = pddl_name.lower()
        for predicate in self.predicates:
            if predicate.name.lower() == target:
                return predicate
        raise KeyError(f"Predicate {pddl_name} not found in domain.")

    def get_parent_operator_ids(self, parent_op_id: str) -> List[str]:
        """Return ``parent_op_id`` followed by the ids of its ancestor operators.

        The chain is followed through ``parent_operator_id`` until an id is reached
        that does not belong to any operator of this domain. That last id (e.g. a
        top-level marker such as ``"root"``) is still included as the final element.
        """
        op_ids = [parent_op_id]
        current_id = parent_op_id
        while True:
            try:
                current_op = self.get_operator_by_id(current_id)
            except KeyError:
                break
            current_id = current_op.parent_operator_id
            if current_id in op_ids:
                # Cyclic parent links would otherwise make this loop run forever.
                break
            op_ids.append(current_id)
        return op_ids

    def remove_for_level(self, parent_op_ids: List[str]) -> "PDDLDomain":
        """Get a domain for a specific level of the hierarchy.

        Only operators whose ``parent_operator_id`` is the LAST entry of
        ``parent_op_ids`` are kept (excluding that operator itself). Among them, only
        the first operator per name is kept. Predicates are kept when their
        ``parent_operator_id`` is any entry of ``parent_op_ids``.

        NOTE(review): ``get_parent_operator_ids`` orders ids from the given operator
        upwards, so its last entry is the outermost ancestor. Confirm callers pass
        the list in the order this method expects.

        Raises:
            IndexError: If ``parent_op_ids`` is empty.
        """
        direct_parent = parent_op_ids[-1]
        operators_by_name: Dict[str, PDDLOperator] = {}
        for op in self.operators:
            if op.id != direct_parent and op.parent_operator_id == direct_parent:
                # Keep the first operator encountered for each name.
                operators_by_name.setdefault(op.name, op)

        predicates = [pred for pred in self.predicates if pred.parent_operator_id in parent_op_ids]
        return self.copy_with(
            operators=list(operators_by_name.values()),
            predicates=predicates,
            types=self.types,
        )

    def add_child_domain(self, child_domain: "PDDLDomain", parent_operator_id: str):
        """Merge ``child_domain`` into this domain in place, below ``parent_operator_id``.

        An entity is "visible" when its ``parent_operator_id`` is
        ``parent_operator_id`` or one of its ancestors.

        Predicates:
          * unknown names are added with ``parent_operator_id``;
          * existing ones owned by ``parent_operator_id`` get the child's definition
            and description;
          * existing ones that are not visible keep their definition. Their owner is
            lifted to the first ancestor of the current owner that is also an
            ancestor of ``parent_operator_id``;
          * predicates owned by ``parent_operator_id`` that the child no longer
            contains are removed.

        Operators:
          * unknown ids are added with ``parent_operator_id``;
          * visible existing ones get the child's definition and description;
          * for existing but not visible ones, a copy is added with
            ``parent_operator_id``. First, same-named operators below
            ``parent_operator_id`` (but not below that operator) are removed;
          * operators owned by ``parent_operator_id`` that the child no longer
            contains are removed.
        """
        parent_op_ids = self.get_parent_operator_ids(parent_operator_id)

        # --- merge predicates ---
        for predicate in child_domain.predicates:
            try:
                prev_predicate = self.get_predicate(predicate.name)
            except KeyError:
                self.predicates.append(predicate.copy_with(parent_operator_id=parent_operator_id))
                continue

            if prev_predicate.parent_operator_id == parent_operator_id:
                prev_predicate.update_inplace(definition=predicate.definition, description=predicate.description)
            elif prev_predicate.parent_operator_id not in parent_op_ids:
                # Not visible from this level: lift it. The definition is kept on
                # purpose, because the existing predicate has most probably already
                # been used and is therefore more accurate.
                prev_parent_op_ids = self.get_parent_operator_ids(prev_predicate.parent_operator_id)
                assert prev_predicate.parent_operator_id == prev_parent_op_ids[0]
                assert any(p_id in parent_op_ids for p_id in prev_parent_op_ids)
                for ancestor_id in prev_parent_op_ids[1:]:
                    if ancestor_id in parent_op_ids:
                        prev_predicate.update_inplace(parent_operator_id=ancestor_id)
                        break

        # --- drop predicates of this level that the child no longer has ---
        # Rebuild in place (slice assignment). Calling list.remove while iterating
        # the same list skips the element after each removal.
        child_pred_names = {p.name.lower() for p in child_domain.predicates}
        self.predicates[:] = [
            p
            for p in self.predicates
            if p.parent_operator_id != parent_operator_id or p.name.lower() in child_pred_names
        ]

        # --- merge operators ---
        for operator in child_domain.operators:
            try:
                prev_operator = self.get_operator_by_id(operator.id)
            except KeyError:
                self.operators.append(operator.copy_with(parent_operator_id=parent_operator_id))
                continue

            if prev_operator.parent_operator_id in parent_op_ids:
                # `operator` was visible: update it, keeping its parent operator id.
                prev_operator.update_inplace(definition=operator.definition, description=operator.description)
                continue

            # `operator` was re-invented on this level. A same-named operator deeper
            # in this subtree (and not below `operator` itself) would now see the new
            # one, so remove it.
            same_name_ops = [op for op in self.operators if op.name == operator.name]
            for same_name_op in same_name_ops:
                op_parent_ids = self.get_parent_operator_ids(same_name_op.parent_operator_id)
                if parent_operator_id in op_parent_ids and operator.id not in op_parent_ids:
                    self.operators.remove(same_name_op)
            self.operators.append(operator.copy_with(parent_operator_id=parent_operator_id))

        # --- drop operators of this level that the child no longer has ---
        child_op_ids = {op.id for op in child_domain.operators}
        self.operators[:] = [
            op
            for op in self.operators
            if op.parent_operator_id != parent_operator_id or op.id in child_op_ids
        ]

    def update_with(self, other: "PDDLDomain"):
        """Update this domain in place from ``other``.

        Operators present in both (by id) take ``other``'s definition and
        description. Operators only in ``other`` are appended (the same objects,
        not copies). Predicates are replaced by a shallow copy of ``other``'s list,
        so later in-place list edits on either domain do not affect the other.
        """
        assert self.types == other.types, "Types must match for update"
        for operator in self.operators:
            try:
                other_op = other.get_operator_by_id(operator.id)
            except KeyError:
                # A child may have added an operator that did not exist in `other`
                # back then; leave it untouched.
                continue
            operator.update_inplace(definition=other_op.definition, description=other_op.description)

        known_ids = {op.id for op in self.operators}
        for operator in other.operators:
            if operator.id not in known_ids:
                self.operators.append(operator)
                known_ids.add(operator.id)

        self.predicates = list(other.predicates)

    def to_pddl(self, domain_name: str = "ai_domain") -> Domain:
        """Convert to a ``pddl.core.Domain`` using ADL requirements."""
        return Domain(
            name=domain_name,
            types=self.types,
            requirements=Requirements.adl_requirements(),
            predicates=[predicate.definition for predicate in self.predicates],
            actions=[operator.definition for operator in self.operators],
        )

    def print_operator_hierarchy(self):
        """Print the operator hierarchy as an indented tree, starting at ``"root"``."""
        ops_by_parent: Dict[str, List[PDDLOperator]] = defaultdict(list)
        for operator in self.operators:
            ops_by_parent[operator.parent_operator_id].append(operator)

        def print_ops(parent_id: str, level: int, ancestors: Tuple[str, ...]):
            indent = "  " * level
            for op in ops_by_parent.get(parent_id, []):
                print(f"{indent}- {op.name} (v: {op.verified}, skill: {','.join(op.mapped_skill_sequence)})")
                # Children are keyed by parent operator *id*, so recurse with op.id.
                # Skip ids already on the current path to survive cyclic parent links.
                if op.id not in ancestors:
                    print_ops(op.id, level + 1, ancestors + (op.id,))

        print_ops("root", 0, ("root",))

    def only_verified_operators(self, remove_unused_predicates: bool = True) -> "PDDLDomain":
        """Return a new domain with only verified operators.

        If ``remove_unused_predicates`` is true, keep only predicates that are
        predefined or used in a precondition or effect of a verified operator.
        Otherwise keep all predicates.
        """
        verified_operators = [op for op in self.operators if op.verified]
        if not remove_unused_predicates:
            return self.copy_with(operators=verified_operators, predicates=self.predicates)

        used_names = set()
        for op in verified_operators:
            used = get_predicates_used_in_formula(op.definition.precondition) + get_predicates_used_in_formula(
                op.definition.effect
            )
            used_names.update(p.name for p in used)
        verified_predicates = [p for p in self.predicates if p.name in used_names or p.predefined]
        return self.copy_with(operators=verified_operators, predicates=verified_predicates)

    def get_change_to(
        self, other: "PDDLDomain"
    ) -> Tuple[Dict[PDDLOperation, List[PDDLOperator]], Dict[PDDLOperation, List[PDDLPredicate]]]:
        """Classify differences between this domain and ``other``.

        Operators are matched by id and predicates by (case-insensitive) name.
        EDIT and REMOVE entries hold this domain's objects. ADD entries hold
        ``other``'s objects.

        Returns:
            ``(operator_changes, predicate_changes)``, each a ``defaultdict(list)``
            keyed by ``PDDLOperation``.
        """
        op_data = defaultdict(list)
        for operator in self.operators:
            try:
                other_op = other.get_operator_by_id(operator.id)
            except KeyError:
                op_data[PDDLOperation.REMOVE].append(operator)
                continue
            if operator != other_op:
                op_data[PDDLOperation.EDIT].append(operator)

        for operator in other.operators:
            try:
                self.get_operator_by_id(operator.id)
            except KeyError:
                op_data[PDDLOperation.ADD].append(operator)

        pred_data = defaultdict(list)
        for predicate in self.predicates:
            try:
                other_pred = other.get_predicate(predicate.name)
            except KeyError:
                pred_data[PDDLOperation.REMOVE].append(predicate)
                continue
            if predicate != other_pred:
                pred_data[PDDLOperation.EDIT].append(predicate)

        for predicate in other.predicates:
            try:
                self.get_predicate(predicate.name)
            except KeyError:
                pred_data[PDDLOperation.ADD].append(predicate)

        return op_data, pred_data
