import numpy as np

from warnings import warn

from symbols.domain.domain import Domain
from symbols.symbols.distribution_symbol import DistributionSymbol
from symbols.symbols.kde import KernelDensityEstimator



# String form of the "not failed" symbol. Schema.add_precondition compares str(symbol)
# against this value and leaves matching symbols out of the precondition conjunction.
# NOTE(review): presumably mirrors KernelDensityEstimator.not_failed() - confirm.
NOT_FAILED = 'notfailed' # KernelDensityEstimator.not_failed()


class Schema:
    """
    An action schema: an option/partition pair with one precondition conjunction and a list of effects.

    ``str(schema)`` renders the schema as a PDDL-style ``(:action ...)`` block.
    """

    # Human-readable option names, indexed by the option identifier
    _OPTION_NAMES = ('Walk to', 'Attack', 'Pickup', 'Walk North Door', 'Walk South Door', 'Through Door', 'Craft',
                     'Open Chest', 'Toggle Door')

    def __init__(self, option, partition):
        """
        Create a new action schema
        :param option:  the option identifier
        :param partition: the partition identifier
        """
        self.option = option
        self.partition = partition
        self._preconditions = list()
        self._effects = list()
        self._id = 0
        # Conjoined distribution over the precondition symbols; set by add_precondition
        self._precondition_conjunction = None
        # NOTE(review): stored but not read anywhere in this class
        self.action_descriptor = None

    def set_id(self, id):
        """
        Set the identifier appended to the action name when the schema is rendered
        :param id: the schema identifier
        """
        self._id = id

    def add_precondition(self, symbols):
        """
        Add a new precondition to the schema. Only the first precondition is kept; any later call emits a
        warning and is ignored.
        :param symbols: a conjunction of symbols in a list
        """
        if self._preconditions:
            warn("Not allowing disjunctions in preconditions (unlike ADL). Ignoring...")
            return

        self._preconditions.append(symbols)
        # The "not failed" symbol is left out of the conjunction used for scoring and sampling
        self._precondition_conjunction = DistributionSymbol.conjoin(
            [x for x in self._preconditions[0] if str(x) != str(NOT_FAILED)])

    def add_effect(self, conjunction):
        """
        Add a new effect to the schema
        :param conjunction: the action rule
        """
        self._effects.append(conjunction)

    def log_likelihood(self, state):
        """
        Score a state under the precondition conjunction. Requires that a precondition has been added,
        since the conjunction is None until then.
        :param state: the state
        :return: the value of the precondition conjunction's score(state)
        """
        return self._precondition_conjunction.score(state)

    def rule_probabilities(self, state):
        """
        Calculate the probability of the state belonging to the distribution represented by each respective effect
        :param state: the state
        :return: the array of probabilities (empty if the schema has no effects)
        """
        if not self._effects:
            # np.max raises on an empty sequence; there is nothing to normalise
            return np.array([])

        log_likelihoods = np.asarray([rule.log_likelihood(state) for rule in self._effects])
        # Softmax, shifted by the maximum for numerical stability
        e_x = np.exp(log_likelihoods - np.max(log_likelihoods))
        return e_x / e_x.sum()

    def __str__(self):
        name = self.option_name.replace(' ', '-')
        lines = [
            '\t;Action ' + name + '-partition-' + str(self.partition),
            '\t(:action ' + name + '_' + str(self._id),
            '\t :parameters()',
            '\t :precondition ' + self._precondition_to_str(),
            '\t :effect ' + self._effect_to_str(),
            '\t)',
        ]
        return '\n'.join(lines)

    @property
    def option_name(self):
        """
        The human-readable name of the option, or str(option) when the option has no entry in the name table
        """
        try:
            return self._OPTION_NAMES[self.option]
        except (IndexError, TypeError):
            return str(self.option)

    def _precondition_to_str(self):
        """
        Render the preconditions as a PDDL goal description
        :return: '()' when there are no preconditions, the single conjunction when there is one, otherwise
        a prefix-form '(or ...)' over the conjunctions
        """
        if not self._preconditions:
            return '()'

        if len(self._preconditions) > 1:
            # PDDL connectives are prefix operators: (or a b), not (a or b)
            return '(or ' + ' '.join(Schema.conjunction_to_str(p) for p in self._preconditions) + ')'

        return Schema.conjunction_to_str(self._preconditions[0])

    def _effect_to_str(self):
        """
        Render the effects: a single effect is rendered on its own, anything else as a '(probabilistic ...)' block
        :return: the effect string
        """
        if len(self._effects) == 1:
            return str(self._effects[0])

        # TODO: hack - fix this - just checking PDDL validity
        return '(probabilistic\n' + '\n'.join(['\t\t\t\t' + str(rule) for rule in self._effects]) + '\n\t\t\t )'

    @staticmethod
    def conjunction_to_str(symbols):
        """
        Render a list of symbols as a conjunction, wrapping each symbol in parentheses
        :param symbols: the symbols in the conjunction
        :return: the lone parenthesised symbol if there is exactly one, otherwise '(and ...)' over all of them
        """
        terms = ' '.join('(' + str(symbol) + ')' for symbol in symbols)
        if len(symbols) == 1:
            return terms
        return '(and ' + terms + ')'

    @property
    def rules(self):
        """
        The list of effects added to this schema
        """
        return self._effects

    def evaluate(self, domain: 'Domain', n_samples):
        """
        Score each effect against next states obtained by executing the option from sampled precondition states.
        For every effect, n_samples states are drawn from the precondition conjunction, the option is run from
        each of them in the domain, and the effect's log-likelihood of the resulting states is averaged.
        :param domain: the domain used to execute the option
        :param n_samples: the number of precondition samples drawn per effect
        :return: a list with one entry per effect, in order: the mean log-likelihood of the sampled next states,
        or None for an effect whose _conjunction is None
        """
        scores = []

        for effect in self._effects:

            if effect._conjunction is None:
                scores.append(None)
                continue

            states = self._precondition_conjunction.get_samples(n_samples)
            mask = self._precondition_conjunction.get_mask()
            next_states = np.array([self._sample_domain(domain, state, mask) for state in states])
            scores.append(np.mean([effect.log_likelihood(state) for state in next_states]))

        return scores

    def _sample_domain(self, domain: 'Domain', state, mask):
        """
        Execute the option from a state whose masked entries come from the given sample and whose remaining
        entries are drawn uniformly from [0, 1)
        :param domain: the domain to initialise and step
        :param state: the sampled values written into the masked entries
        :param mask: the index/mask selecting which entries of the full state are set from state
        :return: the state returned by domain.step if the option can be executed, otherwise the initial state
        as a list
        """
        s = np.random.uniform(0, 1, len(domain.current_state))
        s[mask] = state
        domain.init(s)
        if domain.can_execute(self.option):
            next_state, _, _, _ = domain.step(self.option)
            return next_state
        return list(s)

