import logging
import re
from collections import defaultdict
from typing import List

from pddl.core import Constant, Domain, Formula, Problem
from pddl.logic.base import And, Not
from pddl.logic.predicates import Predicate

from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.task_planning.models.pddl.pddl_problem import PDDLProblem
from tp_lodge.task_planning.pddl_planner.ai_planner.ai_validator import AIValidator
from tp_lodge.utils.pddl_utils import get_effects_from_pred_change

# Module-level logger, named after this module; VAL warnings are reported through it.
logger = logging.getLogger(name=__name__)


class AIStateTransitionRetriever:
    """Reconstruct the sequence of planning states visited by a plan using VAL.

    The plan is validated with ``AIValidator`` (VAL option ``-v``) and the
    ``Adding``/``Deleting`` lines reported for every happening are replayed on
    top of the problem's initial state.
    """

    # Captures everything between the start of VAL's validation details and the
    # success banner; ``[\w\W]`` lets the match span newlines.
    _VALIDATION_DETAILS_RE = re.compile(r"Plan Validation details([\w\W]+)Plan executed successfully")
    # One effect line, e.g. ``Adding (at robot loc1)``: captures the operation
    # word and the text inside the parentheses.
    _EFFECT_RE = re.compile(r"(\w+) \((.+)\)")
    # Header line that starts every happening block in VAL's verbose output.
    _HAPPENING_HEADER = "Checking next happening"

    def __init__(self):
        self.ai_validator = AIValidator()

    def retrieve_transitions(self, domain: Domain, problem: Problem, actions: str) -> List[List[Formula]]:
        """Call VAL and return the state reached after every happening of the plan.

        Args:
            domain: PDDL domain the plan is validated against.
            problem: PDDL problem supplying the initial state and objects; must not be None.
            actions: plan text, passed verbatim to VAL.

        Returns:
            A list of states. Element 0 is the initial state of ``problem``; element
            ``i + 1`` is the state after the ``i``-th happening reported by VAL.
            Each state is a list of literals: a ``Predicate`` for atoms that hold and
            ``Not(Predicate)`` for atoms that are false, e.g.
            ``[Predicate(at, base-part, robot-location), Not(Predicate(...)), ...]``.
            Only atoms that occur in the initial state or in some effect appear.

        Raises:
            AssertionError: if ``problem`` is None or VAL does not report success.
            RuntimeError: if VAL's verbose output cannot be parsed.
            NotImplementedError: for unsupported initial-state literals or VAL operations.
        """
        assert problem is not None
        # VAL's success flag is not checked; the textual response is inspected instead.
        response, _success = self.ai_validator.validate(
            domain=str(domain), problem=str(problem), plan=actions, options="-v"
        )
        assert "Plan executed successfully" in response, f"VAL failed: {response}"

        # Truth value of every atom seen so far, in first-seen order. It is mutated
        # in place while replaying effects; _to_literals takes a fresh snapshot, so
        # states already appended are not affected by later updates.
        truth_values = self._initial_truth_values(problem)
        states = [self._to_literals(truth_values)]

        # A blank plan has no happenings, so the initial state is the only state.
        if not actions.strip():
            return states

        # Domain constants take precedence over problem objects with the same name.
        objects_map = {**{o.name: o for o in problem.objects}, **{o.name: o for o in domain.constants}}

        for happening in self._extract_happenings(response):
            # The first line of each block is the "Checking next happening" header.
            for line in happening.splitlines()[1:]:
                if line.startswith("WARNING:"):
                    logger.warning("WARNING in VAL response: %s", line)
                    continue
                predicate, holds = self._parse_effect_line(line, objects_map)
                truth_values[predicate] = holds
            states.append(self._to_literals(truth_values))

        return states

    @staticmethod
    def _initial_truth_values(problem: Problem) -> dict:
        """Map every atom of ``problem.init`` to True (positive) or False (negated)."""
        truth_values = {}
        for literal in problem.init:
            if isinstance(literal, Not):
                truth_values[literal.argument] = False
            elif isinstance(literal, Predicate):
                truth_values[literal] = True
            else:
                raise NotImplementedError(f"Unsupported literal in initial state: {literal!r}")
        return truth_values

    @staticmethod
    def _to_literals(truth_values: dict) -> List[Formula]:
        """Snapshot a truth-value map as a list of ``Predicate`` / ``Not(Predicate)`` literals."""
        return [atom if holds else Not(atom) for atom, holds in truth_values.items()]

    @classmethod
    def _extract_happenings(cls, response: str) -> List[str]:
        """Return the text of each happening block in VAL's verbose output."""
        match = cls._VALIDATION_DETAILS_RE.search(response)
        if match is None:
            raise RuntimeError(f"Could not find plan validation details in VAL output: {response}")
        return [block for block in match.group(1).split("\n\n") if block.startswith(cls._HAPPENING_HEADER)]

    @classmethod
    def _parse_effect_line(cls, line: str, objects_map: dict) -> tuple:
        """Parse one VAL effect line into ``(predicate, holds)``.

        ``Adding`` yields ``holds=True`` and ``Deleting`` yields ``holds=False``.
        Terms are resolved through ``objects_map`` (a ``KeyError`` signals an unknown object).
        """
        match = cls._EFFECT_RE.search(line)
        if match is None:
            raise RuntimeError(f"Unexpected line in VAL output: {line!r}")
        operation, atom_text = match.groups()
        # split() without an argument tolerates repeated whitespace between terms.
        parts = atom_text.split()
        if not parts:
            raise RuntimeError(f"Empty atom in VAL output line: {line!r}")
        name, *terms = parts
        predicate = Predicate(name, *(objects_map[term] for term in terms))
        if operation == "Adding":
            return predicate, True
        if operation == "Deleting":
            return predicate, False
        raise NotImplementedError(f"Unsupported VAL operation {operation!r} in line: {line!r}")

    def retrieve_goal_state(self, domain: Domain, problem: Problem, actions: str) -> List[Formula]:
        """Return the final state (list of literals) reached by executing ``actions``."""
        return self.retrieve_transitions(domain=domain, problem=problem, actions=actions)[-1]

    def retrieve_single(
        self,
        domain: PDDLDomain,
        objects: List[Constant],
        current_state: List[Formula],
        action: str,
        effects_for_goal: bool = False,
    ) -> PDDLProblem:
        """Build a one-step problem describing the effect of ``action`` on ``current_state``.

        Args:
            domain: domain wrapper; converted with ``to_pddl()`` before validation.
            objects: objects of the generated problems.
            current_state: literals describing the state the action is applied in.
            action: plan text for the single action, passed verbatim to VAL.
            effects_for_goal: if True, the goal is what ``get_effects_from_pred_change``
                derives from the before/after states; otherwise it is the complete
                post-action state.

        Returns:
            A ``PDDLProblem`` whose initial (and grounder initial) state is the state
            before the action and whose goal describes the state after it.

        Raises:
            AssertionError: if VAL does not report exactly one happening.
        """
        # Empty conjunction as goal: only the replayed state trajectory is used here.
        problem = PDDLProblem(objects=objects, initial_state=current_state, goal_state=And())
        state_transitions = self.retrieve_transitions(
            domain=domain.to_pddl(), problem=problem.to_pddl(), actions=action
        )
        assert len(state_transitions) == 2
        before, after = state_transitions

        if effects_for_goal:
            goal_state = get_effects_from_pred_change(before, after)
        else:
            goal_state = And(*after)

        return PDDLProblem(
            objects=objects,
            initial_state=before,
            grounder_initial_state=before,
            goal_state=goal_state,
        )

    def retrieve(
        self, domain: Domain, problem: Problem, actions: str, domain_name: str = "fine_domain"
    ) -> List[Problem]:
        """Split a plan into one sub-problem per step.

        Sub-problem ``i`` starts in the state before step ``i`` and has the state
        after it as its goal (a conjunction of all literals of that state).

        Args:
            domain: PDDL domain the plan is validated against.
            problem: original problem; its requirements, objects and metric are reused.
            actions: plan text, one action per line.
            domain_name: domain name written into every generated problem.

        Returns:
            ``min(number of plan lines, number of reported happenings)`` problems,
            named ``problem_state_<i>``.
        """
        state_transitions = self.retrieve_transitions(domain=domain, problem=problem, actions=actions)
        # Never produce more steps than either plan lines or reported transitions.
        n_steps = min(len(actions.splitlines()), len(state_transitions) - 1)

        return [
            Problem(
                name=f"problem_state_{i}",
                domain_name=domain_name,
                requirements=problem.requirements,
                objects=problem.objects,
                init=frozenset(state_transitions[i]),
                goal=And(*frozenset(state_transitions[i + 1])),
                metric=problem.metric,
            )
            for i in range(n_steps)
        ]
