import copy
import logging
from typing import Dict, List, Optional, Tuple, Union, Set

from pddl.core import Constant, Predicate, Formula
from pddl.logic import Variable
from pddl.logic.base import And, BinaryOp, ExistsCondition, ForallCondition, Formula, Not, UnaryOp
from pddl.logic.effects import When
from pddl.logic.predicates import EqualTo
from python_utils.string_utils import get_markup_from_text

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


def pddls_from_text(text: str) -> List[str]:
    """Extract the PDDL snippets embedded in ``text``.

    Delegates to ``get_markup_from_text`` with the markup tags ``"pddl"`` and
    ``"lisp"``. Presumably these are the language tags of fenced code blocks;
    confirm against ``python_utils.string_utils``.

    Args:
        text: Free-form text (e.g. a model response) that may contain PDDL.

    Returns:
        Whatever ``get_markup_from_text`` extracts for those tags.
    """
    markup_tags = ["pddl", "lisp"]
    return get_markup_from_text(text=text, markup=markup_tags)


def flatten_formula_to_predicates(formula: Formula) -> List[Union[Predicate, Not]]:
    """Flatten ``formula`` into a flat list of literals.

    Structure is discarded rather than preserved:

    * every ``BinaryOp`` (``And`` and any other binary connective) contributes the
      flattened literals of all its operands, in order;
    * a plain ``list``/``tuple``/``set``/``frozenset`` is treated like a
      conjunction of its items;
    * ``ExistsCondition``/``ForallCondition`` contribute their body (the bound
      variables are ignored);
    * ``When`` contributes its condition followed by its effect;
    * ``EqualTo`` atoms are dropped.

    ``Not`` is applied to each literal produced by its argument, and a double
    negation cancels. Every returned item is therefore a ``Predicate`` or
    ``Not(Predicate)``.

    Note that this is *not* De Morgan: ``Not(And(a, b))`` yields
    ``[Not(a), Not(b)]``. The result lists the literals mentioned; it is not an
    equivalent formula.

    Args:
        formula: The formula (or collection of formulas) to flatten.

    Returns:
        The literals mentioned in ``formula``, in traversal order.

    Raises:
        NotImplementedError: For unary operators other than ``Not`` and for any
            other unsupported formula type.
    """
    if isinstance(formula, BinaryOp):
        return [lit for operand in formula.operands for lit in flatten_formula_to_predicates(operand)]
    if isinstance(formula, Predicate):
        return [formula]
    if isinstance(formula, UnaryOp):
        if not isinstance(formula, Not):
            raise NotImplementedError("Unsupported unary operator: %s" % str(type(formula)))
        # Push the negation onto each inner literal. Inner literals are Predicate
        # or Not(Predicate), so Not(Not(p)) collapses back to p and the result
        # keeps the Predicate / Not(Predicate) shape.
        return [
            lit.argument if isinstance(lit, Not) else Not(lit)
            for lit in flatten_formula_to_predicates(formula.argument)
        ]
    if isinstance(formula, (list, tuple, set, frozenset)):
        # A bare collection behaves like a conjunction of its items.
        return [lit for item in formula for lit in flatten_formula_to_predicates(item)]
    if isinstance(formula, EqualTo):
        return []
    if isinstance(formula, (ExistsCondition, ForallCondition)):
        return flatten_formula_to_predicates(formula.condition)
    if isinstance(formula, When):
        return flatten_formula_to_predicates(formula.condition) + flatten_formula_to_predicates(formula.effect)
    raise NotImplementedError("Unsupported formula type: %s" % str(type(formula)))


def get_predicates_used_in_formula(formula: Formula) -> List[Predicate]:
    """Return the predicates mentioned in ``formula``, ignoring negation.

    The formula is flattened with ``flatten_formula_to_predicates``. Each
    ``Not(p)`` literal is reported as ``p``. Order follows the flattened
    traversal, and duplicates are kept.
    """

    def _strip_negation(literal):
        return literal.argument if isinstance(literal, Not) else literal

    return list(map(_strip_negation, flatten_formula_to_predicates(formula)))


def get_valid_predicates(predicates: List[Union[Predicate, Not]]) -> List[Predicate]:
    """Return the predicates that ``predicates`` asserts as true.

    Positive literals are kept and negated ones dropped, in input order.
    ``get_predicate_evaluation`` raises ``RuntimeError`` if a predicate is
    listed twice.
    """
    evaluation = get_predicate_evaluation(predicates=predicates)
    return [atom for atom in evaluation if evaluation[atom]]


def is_predicates_subset(
    predicates: List[Formula], subset_of: List[Formula], *, enforce_parent_set_has_negated: bool = True
) -> bool:
    """Check whether the literals in ``predicates`` are all consistent with ``subset_of``.

    For every predicate in ``predicates``:

    * if it also appears in ``subset_of``, its truth value must match;
    * if it is missing from ``subset_of``:

      - with ``enforce_parent_set_has_negated=True`` (the default), a *false*
        literal is accepted and a *true* one fails the check. The source notes
        this can happen when an operator adds negative predicates the motion
        validator does not generate;
      - with ``enforce_parent_set_has_negated=False``, any missing literal fails
        the check.

    NOTE(review): the flag name reads as the opposite of this behaviour (the
    default ``True`` is the *lenient* mode for missing negated literals).
    Behaviour is kept as-is because callers rely on the default; confirm the
    intended semantics before renaming.

    Returns:
        True if no literal contradicts ``subset_of`` under the rules above.
    """
    candidate = get_predicate_evaluation(predicates=predicates)
    reference = get_predicate_evaluation(predicates=subset_of)

    for atom, value in candidate.items():
        if atom in reference:
            if reference[atom] != value:
                return False
            continue
        # The atom is absent from the reference set.
        if not enforce_parent_set_has_negated or value:
            return False
    return True


def get_predicate_evaluation(predicates: list[Formula]) -> dict[Predicate, bool]:
    """Map every predicate in a conjunction of literals to its truth value.

    Accepted inputs are an ``And``, a single ``Predicate`` or ``Not``, or an
    iterable of such literals. ``p`` maps to ``True`` and ``Not(p)`` maps to
    ``False``. The dict preserves input order.

    Raises:
        AssertionError: An item is neither a ``Predicate`` nor ``Not(Predicate)``.
        RuntimeError: The same predicate occurs more than once (with either
            polarity).
    """
    if isinstance(predicates, And):
        literals = predicates.operands
    elif isinstance(predicates, (Predicate, Not)):
        literals = [predicates]
    else:
        literals = predicates

    evaluation = {}
    for literal in literals:
        negated = isinstance(literal, Not)
        if negated:
            assert isinstance(literal.argument, Predicate)
            atom = literal.argument
        else:
            assert isinstance(literal, Predicate), "Expected predicate or not predicate, got %s" % str(literal)
            atom = literal
        if atom in evaluation:
            raise RuntimeError("Predicate %s already evaluated." % str(atom))
        evaluation[atom] = not negated
    return evaluation


def get_pred_change(prior_predicates: List[Predicate], post_predicates: List[Predicate]):
    """Diff two literal sets describing the states before and after a step.

    Returns:
        A tuple ``(effects, removed, added)``:

        * ``effects``: literals for predicates listed in both sets whose value
          changed, expressed with the post value (``p`` or ``Not(p)``);
        * ``removed``: predicates listed in the prior set but not in the post set;
        * ``added``: literals (``p`` or ``Not(p)``) for predicates listed only in
          the post set.
    """
    before = get_predicate_evaluation(prior_predicates)
    after = get_predicate_evaluation(post_predicates)

    effects = []
    removed = []
    for atom, was_true in before.items():
        if atom not in after:
            # Not mentioned afterwards: treated as "no effect", but reported.
            removed.append(atom)
            continue
        is_true = after.pop(atom)
        if is_true != was_true:
            effects.append(atom if is_true else Not(atom))

    # Whatever is left in ``after`` was never listed in the prior state.
    added = [atom if value else Not(atom) for atom, value in after.items()]
    return effects, removed, added


def get_effects_from_pred_change(prior_predicates: List[Predicate], post_predicates: List[Predicate]) -> Formula:
    """Build the effect formula that turns the prior state into the post state.

    The effect consists of:

    * the changed literals from ``get_pred_change``;
    * every post-only predicate that is *true*.

    Post-only *false* predicates are skipped on purpose; the original author
    marks their meaning as undefined. Predicates dropped from the post state are
    ignored.

    Returns:
        ``And(*effects)``.
    """
    effects, _removed, added = get_pred_change(prior_predicates, post_predicates)
    effects.extend(literal for literal in added if isinstance(literal, Predicate))
    return And(*effects)


def formula_to_list(formula: Union[Formula, List, Tuple, Set]) -> List[Union[Predicate, Not]]:
    """Flatten a conjunction (or nested collections of literals) into a list.

    Lists, tuples, sets and frozensets are expanded item by item, and ``And``
    operands are expanded recursively. ``Predicate`` and ``Not`` items are kept
    as-is; the argument of a ``Not`` is not inspected.

    Raises:
        NotImplementedError: For any other formula type (e.g. ``Or``, quantifiers).
    """
    if isinstance(formula, (list, tuple, set, frozenset)):
        return [literal for item in formula for literal in formula_to_list(item)]
    if isinstance(formula, (Predicate, Not)):
        return [formula]
    if isinstance(formula, And):
        return [literal for operand in formula.operands for literal in formula_to_list(operand)]
    raise NotImplementedError("Unsupported formula type: %s" % str(type(formula)))


def filter_formula_by_predicates(
    formula: Union[Formula, List, Tuple, Set], known_predicates: Optional[List[str]], *, inverse: bool = False
) -> Formula:
    """Keep only the literals whose predicate name is (or, inverted, is not) known.

    Args:
        formula: A conjunction, literal, or collection of literals.
        known_predicates: Predicate names to filter on. ``None`` means "no
            filtering": ``formula`` is returned unchanged when ``inverse`` is
            False, and an empty ``And()`` when ``inverse`` is True.
        inverse: If False, keep predicates whose ``name`` is in
            ``known_predicates``. If True, keep those whose name is not.

    Returns:
        The filtered formula, or ``And()`` when nothing survives.
    """
    if known_predicates is not None:
        kept = _filter_formula_by_predicates(formula, known_predicates, inverse=inverse)
        return And() if kept is None else kept
    return And() if inverse else formula


def _filter_formula_by_predicates(
    formula: Formula, known_predicates: List[str], *, inverse: bool = False
) -> Optional[Formula]:
    """Recursive worker for ``filter_formula_by_predicates``.

    Collections are treated as ``And`` of their items. ``And`` keeps the
    operands that survive. A ``Predicate`` survives if its ``name`` is in
    ``known_predicates`` (or, with ``inverse``, is not). A ``Not`` survives if
    its argument does.

    Returns:
        The filtered formula, or ``None`` when nothing survives.

    Raises:
        NotImplementedError: For any other formula type.
    """
    if isinstance(formula, (tuple, list, set, frozenset)):
        return _filter_formula_by_predicates(And(*formula), known_predicates, inverse=inverse)

    if isinstance(formula, And):
        survivors = []
        for operand in formula.operands:
            kept = _filter_formula_by_predicates(operand, known_predicates, inverse=inverse)
            if kept is not None:
                survivors.append(kept)
        return And(*survivors) if survivors else None

    if isinstance(formula, Predicate):
        is_known = formula.name in known_predicates
        keep = (not is_known) if inverse else is_known
        return formula if keep else None

    if isinstance(formula, Not):
        inner = _filter_formula_by_predicates(formula.argument, known_predicates, inverse=inverse)
        return None if inner is None else Not(inner)

    raise NotImplementedError()


def combine_predicates(preds_a: Union[List, Formula], preds_b: Union[Formula, List]) -> List[Formula]:
    """Concatenate two literal sets into one flat list (``preds_a`` first).

    Each argument may be:

    * a list, which is wrapped in ``And`` first;
    * an ``And``, whose operands are taken;
    * a single ``Predicate`` or ``Not(Predicate)``.

    No de-duplication or consistency check is performed.

    Raises:
        AssertionError: An argument is not one of the accepted shapes.
    """
    groups = [And(*group) if isinstance(group, list) else group for group in (preds_a, preds_b)]

    combined = []
    for group in groups:
        if isinstance(group, And):
            combined.extend(group.operands)
            continue
        assert isinstance(group, Predicate) or (isinstance(group, Not) and isinstance(group.argument, Predicate))
        combined.append(group)
    return combined


def get_matching_predicates(
    pred_a: List[Formula], pred_b: List[Formula], use_false_for_non_existent: bool = False
) -> List[Formula]:
    """Filter ``pred_b`` to contain only predicates that are in ``pred_a``.

    The result follows the order of ``pred_a`` and carries the truth values
    from ``pred_b``. Each predicate is emitted as-is when true in ``pred_b``
    and as ``Not(predicate)`` when false. Predicates of ``pred_b`` that do not
    occur in ``pred_a`` are dropped.

    Args:
        pred_a: Literals (or a conjunction) selecting which predicates to keep.
        pred_b: Literals (or a conjunction) supplying the truth values.
        use_false_for_non_existent: If True, a predicate of ``pred_a`` that is
            missing from ``pred_b`` is emitted as ``Not(predicate)``. If False,
            a missing predicate is an error.

    Raises:
        AssertionError: A predicate of ``pred_a`` is missing from ``pred_b``
            and ``use_false_for_non_existent`` is False. It is raised explicitly,
            so it also fires under ``python -O``.
        RuntimeError: Propagated from ``get_predicate_evaluation`` when an input
            lists the same predicate twice.
    """
    eval_a = get_predicate_evaluation(pred_a)
    eval_b = get_predicate_evaluation(pred_b)

    filtered_b = []
    for predicate in eval_a:
        if predicate in eval_b:
            is_true = eval_b[predicate]
        elif use_false_for_non_existent:
            is_true = False
        else:
            raise AssertionError("Predicate %s not in b" % str(predicate))
        filtered_b.append(predicate if is_true else Not(predicate))
    return filtered_b


def equalized_effects(
    prior: List[Formula], post_a: List[Formula], post_b: List[Formula], *, bidirectional: bool = True
) -> Tuple[List[Formula], List[Formula]]:
    """Make the effect sets of two post states mention the same predicates.

    Both effect sets are computed relative to ``prior`` with
    ``get_effects_from_pred_change``. For every predicate in A's effects that
    B's effects do not mention, one literal is added to B's effects. Its value
    comes from ``post_b`` if listed there, otherwise from ``prior``, and is
    false if the predicate is listed in neither.

    With ``bidirectional=True`` the same is done the other way round (B's
    effects pushed into A). The function then asserts that both results have
    the same number of literals.

    Args:
        prior: Literals of the common prior state.
        post_a: Literals of post state A.
        post_b: Literals of post state B.
        bidirectional: Also equalize A against B.

    Returns:
        ``(effects_a, effects_b)``. ``effects_b`` is a flat list of literals.
        ``effects_a`` is a list when ``bidirectional`` is True; otherwise it is
        the raw effect formula of A.
    """
    # Defensive copies; nothing below appears to mutate the inputs.
    post_a = copy.deepcopy(post_a)
    post_b = copy.deepcopy(post_b)

    eval_prior = get_predicate_evaluation(prior)
    effects_a = get_effects_from_pred_change(prior, post_a)
    effects_b = get_effects_from_pred_change(prior, post_b)
    effects_evals_a = get_predicate_evaluation(effects_a)
    effects_evals_b = get_predicate_evaluation(effects_b)
    evals_b = get_predicate_evaluation(post_b)

    new_effects_b = []
    for predicate_a in effects_evals_a:
        if predicate_a in effects_evals_b:
            # B's effects already mention this predicate -> nothing to add.
            continue

        eval_b = evals_b.get(predicate_a)
        if eval_b is None:
            # Not in B's post state -> unsupported or unchanged: fall back to the prior.
            eval_b = eval_prior.get(predicate_a)

        # ``eval_b`` is still None when the predicate is in neither post_b nor prior,
        # e.g. a predicate the simulation does not support but the planner uses.
        # It is treated as false, and exactly one literal is emitted either way.
        new_effects_b.append(predicate_a if eval_b else Not(predicate_a))

    effects_b = combine_predicates(effects_b, new_effects_b)

    if bidirectional:
        # Other way around: push B's effects into A.
        _, effects_a = equalized_effects(prior=prior, post_a=post_b, post_b=post_a, bidirectional=False)

        assert pred_len(effects_a) == pred_len(effects_b)

    return effects_a, effects_b


def get_aligned_effects(
    gt_prior: List[Formula],
    gt_post: List[Formula],
    exp_prior: List[Formula],
    exp_post: List[Formula],
    supported_predicates: Optional[List[str]],
):
    """Compare expected and ground-truth effects over a common set of predicates.

    Steps:

    1. All four inputs are flattened to literal lists.
    2. Literals of the expected prior/post whose predicate name is *not* in
       ``supported_predicates`` are appended to the ground-truth prior/post.
       When ``supported_predicates`` is None, nothing is appended.
    3. The ground-truth prior (with those additions) must be contained in the
       expected prior. Extra expected-prior literals are tolerated: the original
       author notes the motion validator might not return all possible
       groundings, and the missing ones are expected to be false.
    4. Both effect sets are equalized against the expected prior with
       ``equalized_effects``.

    Returns:
        ``(And(*expected_effects), And(*ground_truth_effects))``.

    Raises:
        RuntimeError: The ground-truth prior contains literals absent from the
            expected prior.
    """
    gt_prior, gt_post, exp_prior, exp_post = map(
        flatten_formula_to_predicates, (gt_prior, gt_post, exp_prior, exp_post)
    )

    unsupported_prior, unsupported_post = (
        filter_formula_by_predicates(formula=literals, known_predicates=supported_predicates, inverse=True)
        for literals in (exp_prior, exp_post)
    )
    gt_prior_full = combine_predicates(gt_prior, unsupported_prior)
    gt_post_full = combine_predicates(gt_post, unsupported_post)

    exp_prior_set = set(exp_prior)
    gt_prior_set = set(gt_prior_full)
    missing_from_exp = list(gt_prior_set - exp_prior_set)
    if missing_from_exp:
        extra_in_exp = list(exp_prior_set - gt_prior_set)
        raise RuntimeError(
            "Prior predicates do not match. This should never happen.\n- %s\n- %s"
            % (str(extra_in_exp), str(missing_from_exp))
        )

    exp_effects, gt_effects = equalized_effects(
        prior=exp_prior, post_a=exp_post, post_b=gt_post_full, bidirectional=True
    )
    assert pred_len(exp_effects) == pred_len(gt_effects)
    return And(*exp_effects), And(*gt_effects)


def pred_len(predicates: Formula) -> int:
    """Count the literals ``flatten_formula_to_predicates`` produces for ``predicates``."""
    return sum(1 for _ in flatten_formula_to_predicates(predicates))


def _flatten_list(lst: List) -> List:
    """Recursively flatten nested ``list`` objects, preserving order.

    Only ``list`` instances are descended into. Tuples, strings and other
    iterables nested inside are kept as single items. The top-level argument
    may be any iterable (callers pass ``map`` objects).
    """

    def _walk(items):
        for element in items:
            if isinstance(element, list):
                yield from _walk(element)
            else:
                yield element

    return list(_walk(lst))


def get_list_of_predicates(formula) -> list[Union[Predicate, Not]]:
    """Return the top-level literals of ``formula`` as a new list.

    * ``And``: a fresh list of its operands.
    * ``list``/``tuple``/``set``/``frozenset``: a list of its items.
    * A single ``Predicate``/``Not``: a one-element list.

    Nested structure is not flattened.

    Raises:
        AssertionError: ``formula`` is none of the accepted shapes.
    """
    if isinstance(formula, And):
        # Copy: returning ``operands`` directly would hand callers the formula's own
        # container (and, depending on the pddl version, possibly not a list).
        return list(formula.operands)
    if isinstance(formula, (list, tuple, set, frozenset)):
        return list(formula)
    assert isinstance(formula, (Not, Predicate))
    return [formula]


def get_variables_in_term(term: Formula) -> List[Variable]:
    """Collect the distinct free variables referenced in ``term``.

    * Lists/tuples, predicates, unary and binary operators, and equality atoms
      return the union of their parts.
    * A ``Variable`` returns itself; a ``Constant`` returns nothing.
    * ``ExistsCondition``/``ForallCondition`` drop the variables they bind
      (matched by ``name``).
    * ``When`` returns the union of the variables in its condition and its
      effect.

    Result order is unspecified (set-based de-duplication).

    Raises:
        NotImplementedError: For unsupported term types.
    """
    if isinstance(term, (list, tuple)):
        return list(set(_flatten_list(map(get_variables_in_term, term))))
    if isinstance(term, Predicate):
        return list(set(_flatten_list(map(get_variables_in_term, term.terms))))
    if isinstance(term, UnaryOp):
        return get_variables_in_term(term.argument)
    if isinstance(term, BinaryOp):
        return list(set(_flatten_list(map(get_variables_in_term, term.operands))))
    if isinstance(term, Variable):
        return [term]
    if isinstance(term, Constant):
        return []
    if isinstance(term, EqualTo):
        return list(set(_flatten_list(map(get_variables_in_term, [term.left, term.right]))))
    if isinstance(term, (ExistsCondition, ForallCondition)):
        bound_names = {t.name for t in term.variables}
        return [t for t in get_variables_in_term(term.condition) if t.name not in bound_names]
    if isinstance(term, When):
        # A conditional effect references every variable of its condition *and* of
        # its effect. The previous intersection dropped variables used on only one side.
        return list(set(get_variables_in_term(term.condition)) | set(get_variables_in_term(term.effect)))
    raise NotImplementedError(str(term))


def get_constants_in_term(term: Formula) -> List[Constant]:
    """Collect the distinct constants referenced anywhere in ``term``.

    Accepts lists/tuples of formulas, predicates, unary and binary operators,
    equality atoms, ``ExistsCondition``/``ForallCondition`` (their body is
    searched) and ``When`` (condition and effect are both searched). This
    mirrors the shapes accepted by ``get_variables_in_term``. A ``Constant``
    returns itself; a ``Variable`` returns nothing.

    Result order is unspecified (set-based de-duplication).

    Raises:
        NotImplementedError: For unsupported term types.
    """
    if isinstance(term, (list, tuple)):
        return list(set(_flatten_list(map(get_constants_in_term, term))))
    if isinstance(term, Predicate):
        return list(set(_flatten_list(map(get_constants_in_term, term.terms))))
    if isinstance(term, UnaryOp):
        return get_constants_in_term(term.argument)
    if isinstance(term, BinaryOp):
        return list(set(_flatten_list(map(get_constants_in_term, term.operands))))
    if isinstance(term, Variable):
        return []
    if isinstance(term, Constant):
        return [term]
    if isinstance(term, EqualTo):
        return list(set(_flatten_list(map(get_constants_in_term, [term.left, term.right]))))
    if isinstance(term, (ExistsCondition, ForallCondition)):
        # Quantifiers bind variables only, so every constant in the body is visible.
        return get_constants_in_term(term.condition)
    if isinstance(term, When):
        return list(set(get_constants_in_term(term.condition)) | set(get_constants_in_term(term.effect)))
    raise NotImplementedError(str(term))


def filter_formula_by_constants(formula: Formula, known_constants: List[str]) -> Optional[Formula]:
    """Keep only the literals whose terms all have names in ``known_constants``.

    Every term of a predicate is compared by its ``name``, so a predicate
    mentioning a variable survives only if that variable's name is listed too.
    ``And`` keeps the operands that survive, and ``Not`` survives if its
    argument does.

    Returns:
        The filtered formula, or ``None`` when nothing survives.

    Raises:
        NotImplementedError: For formula types other than ``And``, ``Predicate``
            and ``Not``.
    """
    if isinstance(formula, And):
        kept = [filter_formula_by_constants(o, known_constants) for o in formula.operands]
        kept = [f for f in kept if f is not None]
        return And(*kept) if kept else None
    if isinstance(formula, Predicate):
        if any(t.name not in known_constants for t in formula.terms):
            return None
        return formula
    if isinstance(formula, Not):
        inner = filter_formula_by_constants(formula.argument, known_constants)
        return None if inner is None else Not(inner)
    raise NotImplementedError("Unsupported formula type: %s" % str(type(formula)))


def get_change_str(init: Formula, final: Formula) -> str:
    """Describe, one line per predicate, the values that changed from ``init`` to ``final``.

    Only predicates listed in ``init`` are considered. A predicate missing from
    ``final`` counts as unchanged. Each line has the form
    ``"- <predicate>: <old> -> <new>"``, and lines are joined with newlines.
    """
    before = get_predicate_evaluation(init)
    after = get_predicate_evaluation(final)

    lines = []
    for atom, old_value in before.items():
        new_value = after.get(atom, old_value)
        if new_value != old_value:
            lines.append(f"- {atom}: {old_value} -> {new_value}")
    return "\n".join(lines)
