import warnings
from warnings import warn

import numpy as np
import itertools

from symbols.pddl.conjunction import Conjunction
from symbols.pddl.schema import Schema, NOT_FAILED
from symbols.symbols.distribution_symbol import DistributionSymbol


class PPDDLSymbolWrapper:
    """
    Pair a symbol with a sign so that it can be rendered as a positive or a negated literal.

    A strictly negative ``sign`` renders the symbol wrapped in ``not (...)``; any other sign renders the
    symbol unchanged.
    """

    def __init__(self, sign, symbol):
        # sign: numeric flag; only "is it below zero" is ever inspected
        # symbol: any object whose str() is its textual form
        self.symbol = symbol
        self.sign = sign

    def __str__(self):
        text = str(self.symbol)
        if self.sign >= 0:
            return text
        return 'not ({!s})'.format(text)


def build_ppddl(factors, operators, proposition_list, option_symbols_indices, verbose=True):
    """
    Build the PPDDL schemata for every learned operator.

    Operators without a precondition are reported on stdout and skipped. For every other operator the
    factors touched by its precondition mask are looked up and the operator's schemata are generated.

    :param factors: the factors (each a collection of state variables)
    :param operators: the learned operators; each exposes a ``precondition`` attribute
    :param proposition_list: the propositions
    :param option_symbols_indices: per operator (indexed like ``operators``), the proposition indices of its effects
    :param verbose: whether to print progress information to screen
    :return: a flat list with the schemata of all operators (not a string)
    """
    all_schemata = []

    for position, op in enumerate(operators):
        if op.precondition is None:
            print("WTF!", position)
            continue

        touched_factors = _mask_to_factors(op.precondition.mask, factors)

        if verbose:
            print("Partitioned option {} of {}".format(position + 1, len(operators)))

        all_schemata.extend(
            _build_forward_model(op, touched_factors, factors, proposition_list, option_symbols_indices[position])
        )

    return all_schemata


def _mask_to_factors(mask, factors):
    """
    Return the indices of the factors that contain at least one variable of the mask.

    :param mask: an iterable of state variables
    :param factors: a sequence of factors, each supporting the ``in`` operator
    :return: a list of factor indices, in increasing order
    """
    return [
        position
        for position, factor in enumerate(factors)
        if any(variable in factor for variable in mask)
    ]


def _masks_overlap(symbols):
    """
    Check whether any state variable appears more than once across the masks of the given symbols.

    :param symbols: a non-empty iterable of objects exposing a ``mask`` attribute
    :return: True if the stacked masks contain a repeated variable, False otherwise
    """
    masks = [symbol.mask for symbol in symbols]
    stacked = np.hstack(masks)
    distinct = np.unique(stacked)
    # Fewer distinct values than entries means some variable was listed twice.
    return distinct.size != len(stacked)

def filter_probs(mean_image_prob, max_prob, min_prob, mean_prob, var_prob, threshold):
    """
    Collapse several probability statistics into a single value.

    The rules are tried in order:
      1. if ``mean_prob`` reaches the threshold, return the larger of ``mean_image_prob`` and ``mean_prob``;
      2. if either mean is below 0.02, return 0;
      3. if ``mean_prob + var_prob`` exceeds the threshold, return the larger of ``mean_image_prob`` and
         ``max_prob``;
      4. otherwise return ``mean_prob``.

    :param mean_image_prob: first mean statistic
    :param max_prob: maximum statistic
    :param min_prob: minimum statistic (accepted but not used)
    :param mean_prob: second mean statistic
    :param var_prob: variance statistic
    :param threshold: the cut-off used by rules 1 and 3
    :return: the selected value
    """
    negligible = 0.02

    if mean_prob >= threshold:
        result = max(mean_image_prob, mean_prob)
    elif mean_image_prob < negligible or mean_prob < negligible:
        result = 0
    elif threshold < mean_prob + var_prob:
        result = max(mean_image_prob, max_prob)
    else:
        result = mean_prob
    return result


def _build_forward_model(operator, option_factors, factors, propositions, option_symbols_indices, verbose=True):
    """
    Build the schemata for the given action operator.

    One list of candidate propositions is collected per factor in ``option_factors``. Every combination
    that picks one candidate from each list (and whose masks do not overlap) is scored against the
    operator's precondition; combinations that are likely enough to be executable each produce one schema
    whose preconditions are the chosen propositions plus ``NOT_FAILED``.

    Two passes are made over the combinations: the first scores with ``mean_prob``; the second, which only
    runs if the first produced no schema at all, scores with ``mean_image_prob``.

    The probabilities and every generated schema are printed to stdout unconditionally.

    :param operator: the symbol for the current action operator; this function reads its ``precondition``,
        ``option``, ``partition``, ``list_probabilities`` and ``list_effects`` attributes
    :param option_factors: indices into ``factors`` of the factors to ground the precondition over
    :param factors: a list of all factors
    :param propositions: a list of all propositions
    :param option_symbols_indices: for each outcome of the operator, the indices of the propositions that the
        effect of that outcome refers to
    :param verbose: currently unused (the only code that read it is commented out)
    :return: a list of schemata for the current action (empty if no combination qualified)
    """

    candidates = []  # one list per option factor: the propositions that may ground that factor
    schemata = list()

    # For every option factor, collect the propositions whose mask is exactly that factor's variable set.
    for f in option_factors:
        s_list = []
        for s in propositions:
            fact = factors[f]
            if set(s.mask) == set(fact):
                s_list.append(s)
        candidates.append(s_list)


    # Execution probabilities at or above this are rounded up to 1; below it a failure effect is added.
    threshold = .8

    # Combinations whose execution probability is below 1 - round_down_thresh are discarded
    # (unless the max_prob patch below applies).
    round_down_thresh = 0.4

    found = False
    # Materialised up front: the loop below rebinds the name `candidates` to each combination.
    combinations = list(itertools.product(*candidates))
    for use_mean_image in [False, True]:
        # Try all possible combinations (one proposition per option factor).
        for candidates in combinations:

            # Skip combinations in which two propositions refer to the same state variable.
            if _masks_overlap(candidates):
                continue

            #todo fix, slow
            # Union of all state variables mentioned by the chosen propositions.
            prefactors = list()
            for z in candidates:
                prefactors += list(z.mask)

            prefactors = set(prefactors)

            # A single proposition is used directly; several are conjoined into one grounding.
            if len(candidates) == 1:
                grounding = candidates[0]
            else:
                grounding = DistributionSymbol.conjoin(candidates)

            # NOTE(review): the exact meaning of the five returned statistics is defined by
            # probability_in_set, which is not visible here; they are unpacked by position.
            mean_image_prob, max_prob, min_prob, mean_prob, var_prob = grounding.probability_in_set(operator.precondition)
            # First pass scores with mean_prob, second pass with mean_image_prob.
            execution_prob = mean_image_prob if use_mean_image else mean_prob


            # print(mean_image_prob, max_prob, min_prob, mean_prob, var_prob, execution_prob)

            # execution_prob = filter_probs(mean_image_prob, max_prob, min_prob, mean_prob, var_prob, threshold)

            # if verbose:
            #     print(str(grounding) + " : " + str(round(execution_prob, 4)))

            if execution_prob < 1 - round_down_thresh:  # discard if less than % chance of success
            # if execution_prob < 0.5:
                # Patch: keep a low-scoring combination when mean_prob is not negligible and max_prob is
                # essentially 1. max_prob then exceeds `threshold`, so it is rounded up to 1 further down.
                if mean_prob > 0.15 and max_prob > 0.99:
                    warnings.warn("Using patch!")
                    execution_prob = max_prob
                else:
                    continue

            print(mean_image_prob, max_prob, min_prob, mean_prob, var_prob, execution_prob)

            found = True
            schema = Schema(operator.option, operator.partition)
            # Preconditions: the chosen propositions plus the NOT_FAILED marker.
            preconditions = list(candidates[:])
            preconditions.append(NOT_FAILED)
            schema.add_precondition(preconditions)

            if execution_prob < threshold:  # if less than % chance of success, add failure option
                # Failure outcome: NOT_FAILED is negated with the remaining probability mass.
                # NOTE(review): the second positional argument (0) is presumably the reward - confirm
                # against Conjunction's signature.
                schema.add_effect(Conjunction([PPDDLSymbolWrapper(-1, NOT_FAILED)], 0, 1 - execution_prob))

            else:
                execution_prob = 1  # round up to 1

            # One effect rule per probabilistic outcome of the operator.
            for i in range(0, len(operator.list_probabilities)):
                outcome_prob = round(operator.list_probabilities[i], 4) * execution_prob  # scale by success probability
                # rew_sym = operator.list_rewards[i]
                # rew = rew_sym.expected_reward(grounding)

                effect_symbols = list()

                # Positive effects.

                positive = []
                pos_idx = []

                # NOTE(review): every proposition is appended to `positive` before the overlap check and is
                # never removed, so after the first overlapping proposition all later ones are rejected too.
                for pos_e in option_symbols_indices[i]:
                    positive.append(propositions[pos_e])
                    if not _masks_overlap(positive):
                        effect_symbols.append(PPDDLSymbolWrapper(1, propositions[pos_e]))
                        pos_idx.append(pos_e)


                # Negative effects: start from every proposition index and filter down.
                neg_e = range(0, len(propositions))

                # Remove positive effects.
                neg_e = [x for x in neg_e if x not in pos_idx]

                # todo no idea here:

                # Keep only propositions whose mask lies entirely within this outcome's effect mask.
                neg_e = [x for x in neg_e if
                         set(propositions[x].mask).issubset(set(operator.list_effects[i].mask))]

                # Filter: in the precondition - only if explicitly mentioned.
                neg_e = [x for x in neg_e if
                         set(propositions[x].mask).issubset(prefactors) and (propositions[x] in candidates)]

                # for n in neg_e:
                #     effect_symbols.append(PPDDLSymbolWrapper(-1, propositions[n]))


                # Point 2, page 20: All propositional symbols with grounding classifier A ⊆ I_o such that factors(A) ⊆
                # factors(o) are set to false.
                # NOTE(review): the subset test below repeats the filter already applied to neg_e above, so
                # it should hold for every remaining index.

                for n in neg_e:
                    if set(propositions[n].mask).issubset(set(operator.list_effects[i].mask)):
                        effect_symbols.append(PPDDLSymbolWrapper(-1, propositions[n]))

                # Point 3, page 21: All currently true propositional symbols with grounding classifier B ⊆ I_o, where
                # f_bi = factors(B) ∩ factors(o) != ∅ but factors(B) is not ⊆ factors(o), are set to false. For
                # each such B, the proposition with grounding symbol Project(B, f_bi) is set to true.
                # NOTE(review): neg_e was already restricted to propositions whose mask is a subset of the
                # effect mask, so the `not factors_b.issubset(factors_o)` condition looks unreachable as
                # written - confirm whether this loop was meant to iterate over a different index set.
                for x in neg_e:

                    factors_b = set(propositions[x].mask)
                    factors_o = set(operator.list_effects[i].mask)

                    if not factors_b.isdisjoint(factors_o) and not factors_b.issubset(factors_o):
                        effect_symbols.append(PPDDLSymbolWrapper(-1, propositions[x]))

                        #Todo must be better way of doing this:
                        warn("Doing KL divergence test (slow!)")
                        f_bi = factors_b.intersection(factors_o)
                        y = grounding.integrate_out(f_bi)
                        # Pick the existing proposition over the same variables with the smallest divergence.
                        closest = _find_closest(y, propositions)
                        if closest is not None:
                            effect_symbols.append(PPDDLSymbolWrapper(1, closest))

                # rule = Conjunction(effect_symbols, rew, outcome_prob, display_only=True)
                # Rewards are currently disabled, hence None in the second position.
                rule = Conjunction(effect_symbols, None, outcome_prob, display_only=True)
                schema.add_effect(rule)



            print(schema)
            schemata.append(schema)

        # Only fall through to the mean_image_prob pass if the mean_prob pass produced nothing.
        if found:
            break
        # if not found:
        #     print("Problem: AAHAHHAHAHAHHAHAHAHHAHAHHAH ")
        # else:
        #     break
    # Neither pass produced a schema for this operator.
    if not found:
        print("Problem: AAHAHHAHAHAHHAHAHAHHAHAHHAH ")
    return schemata


def _find_closest(proposition, propositions):
    """
    Find the proposition over the same variables that has the smallest KL divergence to the given one.

    Only candidates whose mask contains exactly the same variables as ``proposition`` are compared. Ties
    are resolved in favour of the earliest candidate.

    :param proposition: the reference proposition; must expose ``mask`` and ``kl_divergence(other)``
    :param propositions: the candidates to search
    :return: the closest candidate, or None if no candidate yields a divergence below infinity
    """
    best_match = None
    best_divergence = float('inf')

    for other in propositions:
        if set(other.mask) != set(proposition.mask):
            continue

        divergence = proposition.kl_divergence(other)
        # Strict comparison: the first of several equally close candidates wins.
        if divergence < best_divergence:
            best_match, best_divergence = other, divergence

    return best_match