from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional

from pddl.core import And, Constant, Formula, Predicate, Problem
from pddl.logic.base import Not

from tp_lodge.task_planning.models.pddl.pddl_object import PDDLObject
from tp_lodge.utils.pddl_domain_syntax import parse_constant, parse_formula


@dataclass
class PDDLProblem(PDDLObject):
    """A PDDL planning problem: typed objects, initial state(s) and a goal.

    Initial states may be given as a list of predicates, a single
    ``Predicate``, or an ``And`` conjunction. ``__post_init__`` normalizes the
    latter two forms into a plain list of predicates.
    """

    # Typed constants of the problem; __post_init__ asserts each has a type tag.
    objects: List[Constant]
    # Separate initial state kept alongside ``initial_state``. It is not
    # serialized by ``to_json`` and not emitted by ``to_pddl``.
    # NOTE(review): presumably consumed by a grounder elsewhere — confirm.
    grounder_initial_state: Optional[List[Predicate]] = None
    # Facts passed as ``init`` when building the ``pddl`` Problem.
    initial_state: Optional[List[Predicate]] = None
    # Goal formula. ``goal_state_list`` only supports a Predicate, a
    # Not(Predicate), or an And of such formulas.
    goal_state: Optional[Formula] = None

    @staticmethod
    def _as_predicate_list(value):
        """Normalize a conjunction or a single predicate into a list.

        An ``And`` becomes a new list of its operands and a lone ``Predicate``
        is wrapped in a one-element list. Any other value (``None``, an
        existing list, ...) is returned unchanged.
        """
        if isinstance(value, And):
            # ``operands`` is an immutable sequence; copy into a real list so
            # the attribute matches its ``List`` annotation and can be mutated.
            return list(value.operands)
        if isinstance(value, Predicate):
            return [value]
        return value

    def __post_init__(self):
        """Validate object typing and normalize both initial states into lists."""
        if self.objects is not None:
            assert all(o.type_tag is not None for o in self.objects), "All objects must have at least one type tag."
        self.grounder_initial_state = self._as_predicate_list(self.grounder_initial_state)
        self.initial_state = self._as_predicate_list(self.initial_state)

    def copy_with(
        self,
        objects: Optional[List[Constant]] = None,
        grounder_initial_state: Optional[List[Predicate]] = None,
        initial_state: Optional[List[Predicate]] = None,
        goal_state: Optional[Formula] = None,
    ) -> "PDDLProblem":
        """Return a new problem with the given fields replaced.

        Any argument left as ``None`` keeps this problem's current value, so a
        field cannot be cleared to ``None`` through this method. Explicitly
        passed empty lists are applied as given, for every field.
        """
        return PDDLProblem(
            objects=objects if objects is not None else self.objects,
            grounder_initial_state=grounder_initial_state if grounder_initial_state is not None else self.grounder_initial_state,
            initial_state=initial_state if initial_state is not None else self.initial_state,
            goal_state=goal_state if goal_state is not None else self.goal_state,
        )

    @property
    def goal_state_list(self) -> List[Formula]:
        """Return the goal as a flat list of conjuncts.

        Returns:
            The operands of an ``And`` goal (as a new list), or a one-element
            list for a ``Predicate`` or ``Not(Predicate)`` goal.

        Raises:
            NotImplementedError: if the goal is ``None`` or of any other type.
        """
        goal = self.goal_state
        if isinstance(goal, And):
            return list(goal.operands)
        if isinstance(goal, (Predicate, Not)):
            # Only negated atoms are supported, not arbitrary negated formulas.
            assert not isinstance(goal, Not) or isinstance(goal.argument, Predicate), "Negated goals must wrap a Predicate."
            return [goal]
        raise NotImplementedError("Goal state is not a Predicate, Not, or And.")

    @property
    def fully_defined(self) -> bool:
        """True when both an initial state and a goal state are set."""
        return self.initial_state is not None and self.goal_state is not None

    def get_objects_str(self) -> str:
        """Render objects in PDDL typed-list form, grouped by type tag.

        Groups appear in first-seen type order, e.g. ``"a b - t1 c - t2"``.
        """
        grouped_by_type = defaultdict(list)
        for obj in self.objects:
            grouped_by_type[obj.type_tag].append(obj.name)
        return " ".join(f"{' '.join(obj_names)} - {obj_type}" for obj_type, obj_names in grouped_by_type.items())

    def to_json(self) -> dict:
        """Serialize objects, initial state and goal to a JSON-compatible dict.

        ``grounder_initial_state`` is not included. An empty initial state is
        serialized as ``[]`` (distinct from ``None``) so that it survives a
        ``from_json`` round trip.
        """
        return {
            "objects": [f"{obj.name} - {obj.type_tag}" for obj in self.objects],
            "initial_state": [str(pred) for pred in self.initial_state] if self.initial_state is not None else None,
            "goal_state": str(self.goal_state) if self.goal_state is not None else None,
        }

    @classmethod
    def from_json(cls, data: dict) -> "PDDLProblem":
        """Build a problem from a dict in the format produced by ``to_json``.

        ``"objects"`` is required. ``"initial_state"`` and ``"goal_state"`` may
        be missing or ``None``. An empty ``"initial_state"`` list is preserved
        as an empty list; an empty goal string is treated as no goal.
        """
        objects = [parse_constant(obj) for obj in data["objects"]]
        raw_initial_state = data.get("initial_state")
        initial_state = (
            [parse_formula(pred, only_variables=False) for pred in raw_initial_state]
            if raw_initial_state is not None
            else None
        )
        raw_goal_state = data.get("goal_state")
        goal_state = parse_formula(raw_goal_state, only_variables=False) if raw_goal_state else None
        return cls(objects=objects, initial_state=initial_state, goal_state=goal_state)

    def to_pddl(self, domain_name: str = "ai_domain", force: bool = False) -> Problem:
        """Convert to a ``pddl`` ``Problem`` named ``"ai_problem"``.

        ``grounder_initial_state`` is not used; ``initial_state`` becomes
        ``init`` and ``goal_state`` becomes ``goal``.

        Args:
            domain_name: Name of the domain the problem refers to.
            force: Build the problem even when ``fully_defined`` is False.

        Raises:
            AssertionError: if ``force`` is False and the problem is not fully defined.
        """
        assert force or self.fully_defined, "Problem must define initial_state and goal_state (or pass force=True)."
        return Problem(
            domain_name=domain_name,
            name="ai_problem",
            objects=self.objects,
            init=self.initial_state,
            goal=self.goal_state,
        )
