import re
from typing import List, Optional, Tuple

from pddl.core import Action, Constant, Predicate
from pddl.custom_types import _is_a_keyword
from pddl.logic import Variable
from pddl.logic.base import And, ExistsCondition, ForallCondition, Imply, Not, Or
from pddl.logic.effects import When
from pddl.logic.predicates import EqualTo
from python_utils.string_utils import remove_comments

from tp_lodge.utils.pdd_val_utils import validate_global_variables_in_set

# Regex fragment matching a PDDL name: ASCII letters, digits, "_" and "-".
# It is interpolated into the patterns below for predicate, action, section,
# type and variable names (variables add the leading "?" themselves).
name_rgx = r"[a-zA-Z0-9_-]+"


def _until_next_closing_parenthesis(s: str) -> Tuple[str, str]:
    """
    Split ``s`` right after the parenthesis that closes its leading ``(``.

    Nested parenthesis pairs inside the group are skipped, so
    ``"(a (b)) rest"`` is split into ``("(a (b))", " rest")``.

    :param s: string that must start with ``(``.
    :return: ``(group, remainder)`` where ``group`` is the balanced
        parenthesized prefix (both delimiters included) and ``remainder`` is
        everything after it, unmodified.
    :raises ValueError: if ``s`` is empty, does not start with ``(``, or its
        leading ``(`` is never closed.
    """
    # ``not s`` guards the ``s[0]`` lookup so empty input raises ValueError, not IndexError.
    if not s or s[0] != "(":
        raise ValueError("The string must start with an opening parenthesis")
    depth = 0
    for paren in re.finditer(r"[()]", s):
        if paren.group() == "(":
            depth += 1
        else:
            depth -= 1
            if depth == 0:
                split_at = paren.end()
                return s[:split_at], s[split_at:]
    raise ValueError("No closing parenthesis found in the string `%s`" % s)


def _parentheses_groups(s: str):
    """
    Yield the consecutive top-level parenthesized groups of ``s``.

    For ``"(a) (b (c))"`` this yields ``"(a)"`` and then ``"(b (c))"``;
    whitespace between groups is skipped.

    :param s: string that must start with ``(`` and end with ``)`` (callers
        strip it beforehand).
    :raises AssertionError: if ``s`` is empty or not delimited by parentheses.
        ``parse_formula`` catches this exception type and reports a syntax
        error, so it is kept as an ``assert`` (NOTE: stripped under ``python -O``).
    :raises ValueError: from ``_until_next_closing_parenthesis`` when groups are
        unbalanced or separated by non-parenthesized text.
    """
    # startswith/endswith are False for "", so empty input fails the assertion
    # instead of raising IndexError on ``s[0]``.
    assert s.startswith("(") and s.endswith(")"), "The string must start and end with parentheses"
    while s != "":
        next_group, s = _until_next_closing_parenthesis(s.strip())
        yield next_group


def parse_variable(variables_str: str, variable_type: Optional[str] = None):
    """
    Parse a single PDDL variable token such as ``?x`` or ``?x - block``.

    :param variables_str: the variable token, optionally followed by `` - <type>``.
    :param variable_type: type used when ``variables_str`` has no inline
        `` - <type>`` part; an inline type takes precedence.
    :return: ``Variable(name, [type])``, or ``Variable(name, None)`` when no
        type is known.
    :raises ValueError: if the token does not start with ``?<name>``, the type
        is ``object``, or the name is a PDDL keyword.
    """
    if " - " in variables_str:
        variables_str, variable_type = variables_str.strip().split(" - ")
    name_match = re.match(rf"\?({name_rgx})", variables_str)
    if name_match is None:
        # Report malformed tokens as ValueError (what parse_action turns into a
        # syntax error) rather than failing with AttributeError on ``None.group``.
        raise ValueError(f"Syntax error: `{variables_str}` is not a valid variable (variables must start with '?')")
    variable_name = name_match.group(1)
    if variable_type == 'object':
        raise ValueError("Type 'object' is not allowed for variables, as it is the super-type of all types.")
    variable_type = [variable_type] if variable_type is not None else None
    if _is_a_keyword(variable_name):
        raise ValueError(f"Syntax error: {variable_name} is a keyword and cannot be used as a variable")
    return Variable(variable_name, variable_type)


def parse_variable_definitions(variables_str: str):
    """
    Parse a typed variable list such as ``?a ?b - block ?c - table``.

    Every variable must belong to a ``?v1 ?v2 ... - type`` group.

    :param variables_str: the list without its surrounding parentheses.
    :return: list of ``Variable`` objects in declaration order.
    :raises ValueError: if a group's type annotation is malformed or some
        variable is left without a type.
    """
    variables = []
    # One match per "?v1 ?v2 ... - type" group; the lazy type part stops just
    # before the next "?" (start of the next group) or at the end of the string.
    for var_set in re.findall(rf"(?:\?.+? +)+\- +[\w\W]+?(?=(?:\?|$))", variables_str):
        # The pattern accepts several spaces around "-", so split on a whitespace
        # run instead of exactly " - " (which would leave stray spaces in the type).
        parts = re.split(r"\s+-\s+", var_set.strip())
        if len(parts) != 2:
            raise ValueError(
                f"Syntax error: Invalid typed variable list `{var_set.strip()}` (expecting ?var1 ?var2 - type)"
            )
        names_str, var_type = parts
        for variable_token in names_str.split():
            variables.append(parse_variable(variable_token, var_type))
    # Every "?name" in the input must have been consumed by some typed group.
    n_vars = len(re.findall(rf"\?{name_rgx}", variables_str))
    if len(variables) != n_vars:
        raise ValueError(f"Syntax error: All defined variables must have a type. Found {n_vars} variables, but only {len(variables)} have types.")
    return variables


def parse_constant(constant_str: str):
    """
    Parse a PDDL constant such as ``c`` or ``c - block``.

    :param constant_str: the constant token; surrounding whitespace is ignored.
    :return: ``Constant(name, type)``, where ``type`` is the text after `` - ``
        or ``None`` when absent.
    :raises ValueError: if the name is a PDDL keyword (or the token contains
        more than one `` - ``).
    """
    text = constant_str.strip()
    name, type_name = text.split(" - ") if " - " in text else (text, None)
    if _is_a_keyword(name):
        raise ValueError(f"Syntax error: {name} is a keyword and cannot be used as a constant")
    return Constant(name, type_name)


def parse_predicate(
    predicate_str: str, *, only_variables: bool = True, known_predicates: Optional[List[Predicate]] = None
) -> Predicate:
    """
    Parse an atom such as ``(on ?x ?y)`` or a declaration such as ``(on ?x - block ?y - block)``.

    :param predicate_str: text that must start with ``(`` and end with ``)``.
    :param only_variables: if True, every argument must be a ``?variable``.
    :param known_predicates: if given, any predicate in it with the same name
        must agree with the parsed one in argument count and in the first type
        tag of each argument.
    :return: the parsed ``Predicate``.
    :raises ValueError: on malformed syntax, constant arguments while
        ``only_variables`` is set, a keyword used as name, or a mismatch with a
        known predicate.
    """
    assert predicate_str[0] == "(" and predicate_str[-1] == ")", "The predicate must start and end with parentheses"

    parsed = re.match(rf"\(({name_rgx}) *(?: +([\w\W]+))?\)", predicate_str)
    if parsed is None:
        raise ValueError(
            "Syntax error: Invalid predicate definition %s (expecting ({pred_name} {pred_args..}))" % predicate_str
        )
    predicate_name, raw_args = parsed.group(1), parsed.group(2)

    # Argument tokens are "?var", "const" or "?var - type"; other text is ignored.
    arg_tokens = [] if raw_args is None else re.findall(rf"\??{name_rgx}(?: \- {name_rgx})?", raw_args)
    if only_variables and any(not token.startswith("?") for token in arg_tokens):
        raise ValueError(
            f"Syntax error: Predicate arguments of {predicate_name} must be variables. Found: {arg_tokens}"
        )

    if _is_a_keyword(predicate_name):
        raise ValueError(f"Syntax error: {predicate_name} is a keyword and cannot be used as a predicate name")

    parsed_terms = [
        parse_term(token, only_variables=only_variables, known_predicates=known_predicates) for token in arg_tokens
    ]
    new_predicate = Predicate(predicate_name, *parsed_terms)

    if known_predicates is None:
        return new_predicate

    # NOTE(review): arity is compared via ``.args`` but types via ``.terms``;
    # confirm Predicate exposes both, and that ``type_tags`` is indexable.
    for known_predicate in known_predicates:
        if new_predicate.name != known_predicate.name:
            continue
        found_arity = len(new_predicate.args)
        expected_arity = len(known_predicate.args)
        if found_arity != expected_arity:
            raise ValueError(
                f"Syntax error: Predicate {predicate_name} must have {expected_arity} arguments. Found: {found_arity}"
            )
        for new_arg, known_arg in zip(new_predicate.terms, known_predicate.terms):
            found_type = new_arg.type_tags[0]
            expected_type = known_arg.type_tags[0]
            if found_type != expected_type:
                raise ValueError(
                    f"Syntax error: Predicate {predicate_name} must have argument {known_arg.name} of type {expected_type}. Found: {found_type}"
                )

    return new_predicate


def parse_formula(
    formula_str: str,
    only_variables: bool = True,
    *,
    known_predicates: Optional[List[Predicate]] = None,
    unsupported_formulas: Optional[List[str]] = None,
):
    """
    Parse a parenthesized PDDL formula.

    The head symbol decides the result:
    ``and``/``or``/``not``/``imply``/``when`` build the corresponding connective
    from the parenthesized sub-formulas; ``forall``/``exists`` build a
    quantified condition from a typed variable list and one sub-formula; ``=``
    builds an ``EqualTo`` of two terms; any other keyword is rejected; and a
    non-keyword head is parsed as a predicate.

    :param formula_str: text starting with ``(`` and ending with ``)``.
    :param only_variables: if True, constants are rejected as arguments; applied
        at every nesting level.
    :param known_predicates: forwarded to ``parse_predicate`` when the formula
        itself is a predicate.
    :param unsupported_formulas: heads rejected at every nesting level (e.g.
        ``or``/``exists`` in effects), except inside the condition of a
        ``when``, which is a goal description rather than an effect.
    :return: the parsed formula object.
    :raises ValueError: on any syntax error.
    """
    assert formula_str[0] == "(" and formula_str[-1] == ")", "The formula must start and end with parentheses"
    formula_str = remove_comments(formula_str)

    matches = re.match(rf"\(([a-zA-Z0-9_\-\=]+)(?:\s+([\w\W]+))?\)", formula_str)
    if matches is None:
        raise ValueError(
            "Syntax error: Invalid formula definition %s (expecting ({formula_name} {formula_args..}))" % formula_str
        )
    formula_name = matches.group(1)
    formula_content = matches.group(2)

    if unsupported_formulas is not None and formula_name in unsupported_formulas:
        raise ValueError(f"Syntax error: Formula `{formula_name}` is not supported in the current context")

    if not _is_a_keyword(formula_name):
        # predicate
        return parse_predicate(formula_str, only_variables=only_variables, known_predicates=known_predicates)

    # NOTE(review): known_predicates is not forwarded to nested sub-formulas, so
    # only a top-level predicate is checked against it. parse_predicate's check
    # indexes ``type_tags[0]``, which may fail for untyped variables such as
    # ``?x`` -- confirm before forwarding it.
    content = "" if formula_content is None else formula_content.strip()
    if content == "":
        if formula_name == "and":
            # ``(and)`` is the empty conjunction, as in parse_term.
            return And()
        # Without this check, ``None.strip()`` / ``""[0]`` would escape as
        # AttributeError / IndexError instead of a syntax error.
        raise ValueError(f"Syntax error: Formula `{formula_name}` requires arguments: {formula_str}")

    if formula_name in ("exists", "forall"):
        try:
            groups = list(_parentheses_groups(content))
        except AssertionError as err:
            raise ValueError(f"Syntax error: {formula_content} is not a valid formula") from err
        if len(groups) != 2:
            raise ValueError(
                f"Syntax error: `{formula_name}` expects a variable list and one formula: {formula_str}"
            )
        variables_str, conditions_str = groups
        # strip the parentheses around the variable list
        variables = parse_variable_definitions(variables_str[1:-1])
        conditions = parse_term(
            conditions_str, only_variables=only_variables, unsupported_formulas=unsupported_formulas
        )
        if formula_name == "forall":
            return ForallCondition(conditions, variables)
        return ExistsCondition(conditions, variables)

    if formula_name == "=":
        operands = content.split()
        if len(operands) != 2:
            raise ValueError(f"Syntax error: Equality must have exactly two arguments: {formula_str}")
        return EqualTo(*(parse_term(operand, only_variables=only_variables) for operand in operands))

    try:
        terms = [
            parse_term(
                group,
                only_variables=only_variables,
                # A conditional effect's condition is a goal description, so the
                # effect-only restrictions do not apply to it.
                unsupported_formulas=None if (formula_name == "when" and index == 0) else unsupported_formulas,
            )
            for index, group in enumerate(_parentheses_groups(content))
        ]
    except AssertionError as err:
        raise ValueError(f"Syntax error: {formula_content} is not a valid formula") from err

    if formula_name == "and":
        return And(*terms)
    if formula_name == "or":
        return Or(*terms)
    if formula_name == "not":
        if len(terms) != 1:
            raise ValueError(f"Syntax error: Not operator must have one argument: {formula_str}")
        (term,) = terms
        if isinstance(term, Not):
            return term.argument  # double negation, return the inner formula
        return Not(term)
    if formula_name == "when":
        if len(terms) != 2:
            raise ValueError(f"Syntax error: When operator must have a condition and an effect: {formula_str}")
        return When(condition=terms[0], effect=terms[1])
    if formula_name == "imply":
        return Imply(*terms)
    # Any other keyword (e.g. ``if`` or ``implies``) is not a supported formula head.
    raise ValueError("invalid formula name `%s` in `%s`" % (formula_name, formula_str))


def parse_term(
    term_str: str,
    only_variables: bool = True,
    *,
    known_predicates: Optional[List[Predicate]] = None,
    unsupported_formulas: Optional[List[str]] = None,
):
    """
    Parse a single term: an empty conjunction, a formula, a variable or a constant.

    ``()`` and ``(and)`` (optionally with inner whitespace) yield an empty
    ``And``. Any other text starting with ``(`` is parsed with
    ``parse_formula``, text starting with ``?`` is a variable, and anything
    else is a constant.

    :param term_str: the term text (not stripped).
    :param only_variables: if True, constants are rejected.
    :param known_predicates: forwarded to ``parse_formula``.
    :param unsupported_formulas: forwarded to ``parse_formula``.
    :return: the parsed formula, ``Variable`` or ``Constant``.
    :raises ValueError: on an empty term, a constant while ``only_variables``
        is set, or any error raised by the delegated parser.
    """
    if not term_str:
        # Guard the ``term_str[0]`` lookups below, which would raise IndexError.
        raise ValueError("Syntax error: empty term")
    if re.match(r"\(\s*(?:and)?\s*\)", term_str):
        return And()
    if term_str[0] == "(":
        # must be a formula
        return parse_formula(
            term_str,
            only_variables=only_variables,
            known_predicates=known_predicates,
            unsupported_formulas=unsupported_formulas,
        )
    if term_str[0] == "?":
        # variable
        return parse_variable(term_str)
    if only_variables:
        raise ValueError(
            f"Syntax error: {term_str} is not a valid variable. Constants are not allowed in this context"
        )
    return parse_constant(term_str)


def parse_action(
    action_str: str, *, previous_action: Optional[Action] = None, known_predicates: Optional[List[Predicate]] = None
):
    """
    Parse a ``(:action <name> :parameters (...) :precondition (...) :effect (...))`` block.

    If ``action_str`` is not a full ``(:action ...)`` definition and
    ``previous_action`` is given, ``action_str`` is treated as a sequence of
    ``:parameters`` / ``:precondition`` / ``:effect`` sections. Each section
    found overrides the corresponding part of ``previous_action``; the others
    are inherited.

    :param action_str: PDDL source of the action (comments are removed first).
    :param previous_action: optional action supplying the name and default sections.
    :param known_predicates: forwarded to precondition/effect parsing.
    :return: the parsed ``Action``.
    :raises ValueError: on syntax errors, unknown sections, or a missing
        parameters/precondition/effect section.
    """
    supported_var_types = [":parameters", ":precondition", ":effect"]

    action_str = remove_comments(action_str)
    # ``\s+`` lets the name follow ``:action`` after any whitespace (e.g. a
    # newline). ``*`` keeps the whole name intact for a section-less action, so
    # the "must have parameters" check below reports it.
    matches = re.match(rf"\(:action\s+({name_rgx})([\w\W]*)\)", action_str.strip())
    if matches is None:
        if previous_action is None:
            raise ValueError(f"Syntax error: {action_str} is not a valid action definition")
        # Not a full definition: the text only contains sections that update previous_action.
        action_name = previous_action.name
        after_group = action_str
    else:
        action_name = matches.group(1)
        after_group = matches.group(2)

    if previous_action is not None:
        parameters = previous_action.parameters
        precondition = previous_action.precondition
        effect = previous_action.effect
    else:
        parameters = precondition = effect = None

    # Each iteration consumes one ":<section> (<content>)" pair; stop once only
    # whitespace is left.
    while after_group.strip() != "":
        section = re.match(rf"(:{name_rgx})([\w\W]+)", after_group.strip())

        if section is None:
            raise ValueError(
                "Syntax error: Invalid pddl action definition for action %s:\n`%s`" % (action_name, after_group.strip())
            )

        var_type = section.group(1)
        var_content, after_group = _until_next_closing_parenthesis(section.group(2).strip())

        if var_type == ":parameters":
            try:
                # drop the parentheses around the parameter list
                parameters = parse_variable_definitions(var_content[1:-1])
            except ValueError as e:
                raise ValueError("Parsing parameters in `%s`: %s" % (action_name, str(e))) from e
        elif var_type == ":precondition":
            try:
                precondition = parse_term(var_content, known_predicates=known_predicates)
            except ValueError as e:
                raise ValueError("Parsing precondition in `%s`: %s" % (action_name, str(e))) from e
        elif var_type == ":effect":
            try:
                # disjunctions and existentials are not allowed in effects
                effect = parse_term(
                    var_content, known_predicates=known_predicates, unsupported_formulas=["or", "exists"]
                )
            except ValueError as e:
                raise ValueError("Parsing effect in `%s`: %s" % (action_name, str(e))) from e
        else:
            raise ValueError(
                f"Syntax error: {var_type} is an invalid type for an action definition "
                f"(supported are: {', '.join(supported_var_types)}). Correct the syntax of action {action_name}"
            )

    if parameters is None:
        raise ValueError(f"Syntax error: Action {action_name} must have parameters")
    if precondition is None:
        raise ValueError(f"Syntax error: Action {action_name} must have a precondition")
    if effect is None:
        raise ValueError(f"Syntax error: Action {action_name} must have an effect")

    validate_global_variables_in_set(precondition, parameters)
    validate_global_variables_in_set(effect, parameters)

    return Action(
        name=action_name,
        parameters=parameters,
        precondition=precondition,
        effect=effect,
    )


def parse_domain(domain_str: str):
    """
    Check a PDDL domain definition by walking its top-level sections.

    Lines whose first non-blank characters are ``;;`` are dropped first.
    ``:requirements``, ``:types`` and ``:predicates`` sections are accepted
    without further checks. Each ``:action`` is parsed with ``parse_action``
    and the result is discarded, so the function is used for its exceptions only.

    :param domain_str: full ``(define (domain <name>) ...)`` text.
    :raises ValueError: if the text is not a ``(define (domain <name>) ...)``
        form, a section is malformed or unsupported (including ``:constants``),
        or an action fails to parse.
    """
    domain_str = "\n".join(line for line in domain_str.split("\n") if not line.strip().startswith(";;"))
    # ``[\w-]+`` (rather than ``\w+``) accepts hyphenated names such as
    # ``blocks-world``. Stripping allows blank lines before ``(define``.
    domain_match = re.match(r"\(define\s+\(domain\s+[\w-]+\s*\)([\w\W]+)\)", domain_str.strip())
    if domain_match is None:
        raise ValueError("Syntax error: expecting a domain definition of the form (define (domain <name>) ...)")

    after_group = domain_match.group(1).strip()
    while after_group.strip() != "":
        next_group, after_group = _until_next_closing_parenthesis(after_group.strip())

        section_match = re.match(r"\((:\w+)", next_group)
        if section_match is None:
            raise ValueError(
                f"Syntax error: expecting a section such as (:action ...) in the domain definition, found `{next_group}`"
            )
        next_group_type = section_match.group(1)

        if next_group_type in (":requirements", ":types", ":predicates"):
            # accepted as-is: no type checking
            continue
        if next_group_type == ":action":
            parse_action(next_group)
        elif next_group_type == ":constants":
            raise ValueError(
                "Syntax error: Global constants are not supported in the domain definition. Rather use variables."
            )
        else:
            raise ValueError(f"Syntax error: type {next_group_type} is not supported in the domain definition")
