from collections import defaultdict
from itertools import product
from typing import Dict, List, Tuple

from pddl.core import And, Constant, Formula, Predicate
from pddl.logic.base import Not

from furniture_bench_api.api.api_predicates import (
    StateArmAbovePartPredicate,
    StateArmAtPartPredicate,
    StateArmEmptyPredicate,
    StateAssembledPredicate,
    StateGripperOpenPredicate,
    StateHoldingPredicate,
    StateOnTablePredicate,
    StatePredicate,
    StateTouchingPredicate,
)
from furniture_bench_api.furniture_bench_environment import FurnitureBenchEnvironment

# Dispatch table from PDDL predicate names to the state checker that evaluates
# them against the environment. Each key gets its own checker instance.
# Some names share a checker class:
#   "touching" / "inserted"                    -> StateTouchingPredicate
#   "gripper-near-part" / "gripper-above-part" -> StateArmAbovePartPredicate
# NOTE(review): confirm these aliases are intended and not placeholders.
# A "holding" entry (StateHoldingPredicate) is currently disabled; the same
# checker class is exposed under "gripper-grasps".
supported_predicates: Dict[str, StatePredicate] = {
    pddl_name: checker_cls()
    for pddl_name, checker_cls in (
        ("touching", StateTouchingPredicate),
        ("inserted", StateTouchingPredicate),
        ("assembled", StateAssembledPredicate),
        ("gripper-near-part", StateArmAbovePartPredicate),
        ("gripper-above-part", StateArmAbovePartPredicate),
        ("gripper-positioned-for-grasp", StateArmAtPartPredicate),
        ("gripper-empty", StateArmEmptyPredicate),
        ("gripper-open", StateGripperOpenPredicate),
        ("gripper-grasps", StateHoldingPredicate),
        ("on-table", StateOnTablePredicate),
    )
}


class StateValidator:
    """Check PDDL formulas against the state of a FurnitureBench environment.

    Each PDDL predicate name is looked up in the module-level
    ``supported_predicates`` table. The matching ``StatePredicate``'s
    ``validate`` method is called with the environment and the names of the
    predicate's terms.
    """

    def __init__(self, env: FurnitureBenchEnvironment):
        """Store the environment that predicates are evaluated against."""
        self.env = env

    def get_supported_predicates(self) -> Dict[str, StatePredicate]:
        """Return a shallow copy of the predicate-name -> checker table."""
        return supported_predicates.copy()

    def validate_predicate(self, predicate: Predicate, *, skip_unsupported_predicates: bool = True) -> Tuple[bool, str]:
        """Evaluate a single ground PDDL predicate in the environment.

        Args:
            predicate: PDDL predicate. The ``name`` of each of its terms is passed
                positionally to the matching checker's ``validate``.
            skip_unsupported_predicates: if True, a predicate with no registered
                checker is reported as satisfied with an empty explanation
                instead of raising.

        Returns:
            ``(is_valid, explanation)`` as produced by the checker.

        Raises:
            ValueError: if the predicate is unsupported and
                ``skip_unsupported_predicates`` is False.
            TypeError: re-raised from the checker call (for example when the
                number of arguments does not match), after printing the
                offending predicate and arguments.
        """
        predicate_name = predicate.name
        predicate_args = [term.name for term in predicate.terms]

        # Use a distinct name so the PDDL `predicate` argument is not shadowed.
        checker = supported_predicates.get(predicate_name)
        if checker is None:
            if skip_unsupported_predicates:
                print("Skipping unsupported predicate '%s'" % predicate_name)
                return True, ""
            raise ValueError("Unsupported predicate %s" % predicate_name)

        try:
            is_valid, explanation = checker.validate(self.env, *predicate_args)
        except TypeError as e:
            print(f"Type error in predicate '{predicate_name}' with args {predicate_args}: {e}")
            # Bare raise keeps the original traceback intact.
            raise
        return is_valid, explanation

    def validate_state(self, state: Formula, *, skip_unsupported_predicates: bool = True) -> Tuple[bool, Formula]:
        """Recursively evaluate a formula built from ``And``, ``Not`` and predicates.

        Args:
            state: formula to check. Only ``And``, ``Not`` and ``Predicate``
                nodes are supported.
            skip_unsupported_predicates: if True, literals whose predicate has
                no checker are treated as satisfied and echoed unchanged. This
                applies to both positive and negated literals.

        Returns:
            ``(is_valid, description)``. ``is_valid`` says whether ``state``
            holds. ``description`` mirrors the shape of ``state``: every checked
            atom appears as itself if it holds and as ``Not(atom)`` otherwise.

        Raises:
            NotImplementedError: for any other formula type.
            ValueError: for unsupported predicates when
                ``skip_unsupported_predicates`` is False.
        """
        if isinstance(state, And):
            # Evaluate every operand (no short-circuit) so the description is complete.
            results = [
                self.validate_state(op, skip_unsupported_predicates=skip_unsupported_predicates)
                for op in state.operands
            ]
            is_valid = all(op_valid for op_valid, _ in results)
            return is_valid, And(*(op_desc for _, op_desc in results))
        elif isinstance(state, Predicate):
            is_valid, _ = self.validate_predicate(state, skip_unsupported_predicates=skip_unsupported_predicates)
            return is_valid, (state if is_valid else Not(state))
        elif isinstance(state, Not):
            atom = state.argument
            while isinstance(atom, Not):
                atom = atom.argument
            if (
                skip_unsupported_predicates
                and isinstance(atom, Predicate)
                and atom.name not in supported_predicates
            ):
                # validate_predicate reports skipped predicates as satisfied.
                # Negating that would make every negated unsupported literal
                # fail, so the whole literal is skipped instead.
                print("Skipping unsupported predicate '%s'" % atom.name)
                return True, state
            is_valid, desc = self.validate_state(
                state.argument, skip_unsupported_predicates=skip_unsupported_predicates
            )
            # The description of the inner formula already reflects what was observed.
            return not is_valid, desc
        else:
            raise NotImplementedError()

    def get_predicates_evaluation(self, known_predicates: List[Predicate], objects: List[Constant]) -> List[Formula]:
        """Ground every supported predicate over ``objects`` and evaluate it.

        For each predicate in ``known_predicates`` that has a registered
        checker, the method builds every assignment of objects to its
        parameters. Objects are matched by type tag, and no object may appear
        twice in one grounding. Each assignment becomes a ground ``Predicate``
        that is evaluated in the environment. Predicates without a checker are
        skipped.

        Args:
            known_predicates: lifted predicate declarations. Each term must
                carry exactly one type tag.
            objects: typed constants available for grounding.

        Returns:
            One literal per evaluated grounding: the ground predicate if it
            holds, ``Not(predicate)`` otherwise.
        """
        objects_per_type = defaultdict(list)
        for obj in objects:
            objects_per_type[obj.type_tag].append(obj.name)

        predicates_evaluation = []
        for known_predicate in known_predicates:
            assert all(len(t.type_tags) == 1 for t in known_predicate.terms)
            if known_predicate.name not in supported_predicates:
                # Unsupported predicates cannot be evaluated. Skip them before
                # enumerating their groundings.
                # TODO: later, all predicates given in initial domain should be grounded
                continue
            term_types = [next(iter(term.type_tags)) for term in known_predicate.terms]
            objects_per_arg = [objects_per_type[term_type] for term_type in term_types]
            for args in product(*objects_per_arg):
                if len(args) != len(set(args)):
                    print("predicates with multiple times same args are ignored")
                    continue
                predicate = Predicate(
                    known_predicate.name,
                    *[Constant(name=arg, type_tag=term_type) for arg, term_type in zip(args, term_types)],
                )
                try:
                    is_valid, _ = self.validate_predicate(predicate, skip_unsupported_predicates=False)
                except ValueError as e:
                    # Unsupported names were filtered out above, so this error
                    # comes from the checker itself. Keep the best-effort skip,
                    # but report it rather than swallowing it silently.
                    print(f"Could not evaluate predicate '{known_predicate.name}' with args {list(args)}: {e}")
                    continue

                predicates_evaluation.append(predicate if is_valid else Not(predicate))
        return predicates_evaluation
