import logging
import re
from typing import Dict, List, Optional, Set

from pddl.core import Action, Domain, Problem
from pddl.logic.base import Formula

from tp_lodge.utils.pddl_utils import get_effects_from_pred_change

logger = logging.getLogger(__name__)


def copy_action_w_args(
    action: Action,
    *,
    name: Optional[str] = None,
    precondition: Optional[Formula] = None,
    effect: Optional[Formula] = None,
) -> Action:
    """Return a new Action based on ``action`` with selected fields replaced.

    The parameters are always reused from ``action``.

    Args:
        action: The action to copy.
        name: Replacement name; a falsy value keeps ``action.name``.
        precondition: Replacement precondition, or ``None`` to keep
            ``action.precondition``.
        effect: Replacement effect, or ``None`` to keep ``action.effect``.

    Returns:
        A new ``Action`` instance.
    """
    # Formula overrides are compared against None (not truthiness), matching how
    # copy_domain_w_args / copy_problem_w_args treat their overrides, so a falsy
    # formula passed explicitly is honoured instead of silently discarded.
    return Action(
        name=name or action.name,
        parameters=action.parameters,
        precondition=precondition if precondition is not None else action.precondition,
        effect=effect if effect is not None else action.effect,
    )


def copy_domain_w_args(
    domain: Domain,
    *,
    types: Optional[Dict] = None,
    name: Optional[str] = None,
    predicates: Optional[Set] = None,
    actions: Optional[List[Action]] = None,
):
    """Build a new Domain from ``domain``, substituting any provided fields.

    ``name`` falls back to ``domain.name`` when it is falsy. ``types``,
    ``predicates`` and ``actions`` fall back to the corresponding attribute of
    ``domain`` only when they are ``None``, so an explicitly empty collection
    is kept. Requirements, constants, derived predicates and functions are
    always taken from ``domain``.
    """
    if types is None:
        types = domain.types
    if predicates is None:
        predicates = domain.predicates
    if actions is None:
        actions = domain.actions

    return Domain(
        name=name or domain.name,
        requirements=domain.requirements,
        types=types,
        constants=domain.constants,
        predicates=predicates,
        derived_predicates=domain.derived_predicates,
        functions=domain.functions,
        actions=actions,
    )


def copy_problem_w_args(
    problem: Problem,
    *,
    domain_name: Optional[str] = None,
    objects: Optional[Set] = None,
    init: Optional[List[Formula]] = None,
    goal: Optional[Formula] = None,
):
    """Create a new Problem mirroring ``problem`` with selected fields overridden.

    ``domain_name``, ``objects``, ``init`` and ``goal`` replace the matching
    attributes of ``problem`` whenever they are not ``None``. ``name``,
    ``requirements`` and ``metric`` are always carried over unchanged.
    """
    if domain_name is None:
        domain_name = problem.domain_name
    if goal is None:
        goal = problem.goal
    if init is None:
        init = problem.init
    if objects is None:
        objects = problem.objects

    return Problem(
        name=problem.name,
        requirements=problem.requirements,
        domain_name=domain_name,
        goal=goal,
        init=init,
        metric=problem.metric,
        objects=objects,
    )


def clean_problem_goal_state(problem: Problem) -> Problem:
    """Rebuild ``problem`` so its goal only mentions predicates that change from init.

    The goal literals are compared against the initial state with
    ``get_effects_from_pred_change``. Its result becomes the new goal, so the
    goal only constrains what changes and the plan may overshoot elsewhere.
    (The exact filtering is defined by ``get_effects_from_pred_change``.)

    Args:
        problem: The problem whose goal should be cleaned.

    Returns:
        A copy of ``problem`` with the new goal; all other fields are reused.
    """
    init = problem.init
    goal = problem.goal

    # A conjunctive goal exposes its conjuncts via ``operands``. A goal made of a
    # single literal (e.g. ``(:goal (on a b))``) has no ``operands``, so treat it
    # as a one-element conjunction instead of failing with AttributeError.
    goal_literals = getattr(goal, "operands", None)
    if goal_literals is None:
        goal_literals = [goal]

    effect = get_effects_from_pred_change(prior_predicates=init, post_predicates=goal_literals)

    return copy_problem_w_args(problem=problem, goal=effect)


def remove_types_from_domain(domain_str: str) -> str:
    """Strip typing information from a PDDL domain string.

    Two rewrites are applied:

    1. Every typed variable ``?var - type`` or ``?var - (either t1 t2 ...)`` is
       reduced to the bare ``?var``. Variable and type names may contain
       hyphens, which is common in PDDL (e.g. ``?from-loc - location``).
    2. The ``(:types ...)`` section is deleted, together with everything up to
       the next ``(:constants``, ``(:predicates``, ``(:functions``,
       ``(:derived``, ``(:action`` or ``(:durative-action`` header.

    Typed constants (e.g. ``(:constants a b - block)``) are not modified.

    Args:
        domain_str: PDDL domain text.

    Returns:
        The domain text with the typing information described above removed.
    """
    # A single pass is enough: in valid PDDL, stripping ``?var - type`` can never
    # create a new ``?var - type`` match, so this is equivalent to rescanning
    # from the start after every replacement (which was quadratic).
    domain_str = re.sub(
        r"(\?[\w-]+)\s+-\s+(?:\(either\s[^()]*\)|[\w-]+)",
        r"\1",
        domain_str,
    )

    # Lazily consume the types section up to (not including) the next section
    # header; ``(either ...)`` groups inside the section are simply skipped over.
    domain_str = re.sub(
        r"\(:types\s[\w\W]+?(?=\(:(?:constants|predicates|functions|derived|action|durative-action))",
        "",
        domain_str,
    )

    return domain_str


def check_consistency(domain: Domain, problem: Problem) -> None:
    """Check that ``problem`` is consistent with ``domain``.

    First runs ``problem.check(domain)``. Then verifies, by name, that every
    predicate returned by ``get_predicate_evaluation`` for the problem's initial
    state is declared in ``domain.predicates``.

    Args:
        domain: The planning domain.
        problem: The planning problem to validate against ``domain``.

    Raises:
        ValueError: If the initial state uses a predicate name the domain does
            not declare. Exceptions from ``problem.check`` propagate unchanged.
    """
    problem.check(domain)

    # Set of declared names gives O(1) membership per init predicate.
    known_predicates = {p.name for p in domain.predicates}

    from tp_lodge.utils.pddl_utils import get_predicate_evaluation

    predicates = list(get_predicate_evaluation(list(problem.init)).keys())
    unknown_predicates = [p for p in predicates if p.name not in known_predicates]
    if unknown_predicates:
        raise ValueError(
            f"Unknown predicates in problem: {', '.join(str(p) for p in unknown_predicates)}. "
            "Make sure the domain and problem are consistent."
        )
