from sklearn.svm import OneClassSVM

from symbols.symbols.learned_operator import LearnedOperator
import numpy as np
import itertools

from symbols.utils import samples2np


def _rescale(logit):
    odds = np.exp(logit)
    return odds / (1 + odds)


class DummyTransferableLearnedOperator:
    """Scores transitions with a learned operator on the operator's own masks.

    Exposes the same ``score`` signature as ``TransferableLearnedOperator``
    but performs no type-based re-grounding: the precondition is evaluated
    with ``use_mask=True`` and each effect is checked on ``effect.mask``.
    """

    def __init__(self, operator: "LearnedOperator"):
        """
        :param operator: the learned operator to wrap
        """
        self.operator = operator
        self.option = None
        # one fitted one-class SVM per effect index, created lazily by score_effect
        self._ocs = dict()

    def score(self, state, option, next_state, _, __):
        """Score how well the transition ``state -> next_state`` matches the operator.

        :param state: the state before execution
        :param option: unused
        :param next_state: the state after execution
        :return: a pair ``(state_prob, best)``: the precondition probability of
            ``state``, and that probability multiplied by the best effect score
            (0 when the product is not positive)
        """
        state_prob = self.operator.precondition.probability(state, use_mask=True)
        max_eff = 0
        for i, _prob in enumerate(self.operator.list_probabilities):
            eff = self.score_effect(i, self.operator.list_effects[i], next_state)
            max_eff = max(max_eff, eff)
        total_score = state_prob * max_eff
        best = total_score if total_score > 0 else 0
        return state_prob, best

    def score_effect(self, i, effect, next_state):
        """Check whether ``next_state`` looks like an outcome of ``effect``.

        On first use for index ``i``, a one-class SVM is fitted to 100 samples
        drawn from ``effect`` and cached. The entries of ``next_state`` selected
        by ``effect.mask`` are then flattened into one vector and classified.

        :param i: index of the effect, used as the cache key
        :param effect: the effect distribution; must provide ``sample`` and ``mask``
        :param next_state: the state after execution
        :return: 1 if the SVM predicts an inlier, otherwise 0
        """
        if i not in self._ocs:
            X = samples2np(effect.sample(100))
            self._ocs[i] = OneClassSVM(gamma='auto').fit(X)
        Y = self._flatten(next_state[effect.mask])
        sign = self._ocs[i].predict(np.reshape(Y, (1, -1)))[0]
        return 1 if sign > 0 else 0

    @staticmethod
    def _flatten(parts):
        """Concatenate the per-object arrays in ``parts`` into one flat float vector."""
        arrays = [np.asarray(x, dtype=float) for x in parts]
        return np.concatenate(arrays) if arrays else np.zeros(0)

class TransferableLearnedOperator:
    """Like a learned operator, but with info on which objects are which type (so we can transfer).

    ``types`` is indexed by the entries of the operator's masks and gives the
    type of each of those objects. ``score`` uses these types to try the
    operator on every combination of currently present objects whose types
    match. Only operators with a single effect are supported.
    """

    # Number of effect samples drawn to fit a one-class SVM, and the minimum drawn by forward()
    _N_SAMPLES = 100

    def __init__(self, operator: "LearnedOperator", types):
        """
        :param operator: the learned operator to wrap
        :param types: lookup from an object index (as used in the operator's masks) to its type
        :raises ValueError: if the operator has more than one effect
        """
        self.operator = operator
        self.classes = types
        self.linking = dict()

        self.precondition_mask = [types[x] for x in operator.precondition.mask]
        if len(operator.list_effects) > 1:
            raise ValueError(
                "TransferableLearnedOperator supports a single effect, got {}".format(
                    len(operator.list_effects)))
        self.effect_mask = [types[x] for x in operator.list_effects[0].mask]

        self.mask = None

        # one fitted one-class SVM per effect index, created lazily by score_effect
        self._ocs = dict()

    def reset(self):
        """Forget all links and the instantiated mask."""
        self.linking = dict()
        self.mask = None

    def link(self, pre, eff):
        """Record that symbol ``pre`` leads to symbol ``eff`` (see ``forward``)."""
        self.linking[pre] = eff

    def instantiate(self, mask):
        """Store ``mask`` as the current instantiation."""
        self.mask = mask

    @property
    def option(self):
        """The option of the wrapped operator."""
        return self.operator.option

    @property
    def partition(self):
        """The partition of the wrapped operator."""
        return self.operator.partition

    @partition.setter
    def partition(self, value):
        self.operator.partition = value

    def _find_types(self, types, current):
        """Enumerate every way of picking one object per requested type.

        :param types: the sequence of required types, one per mask slot
        :param current: mapping from object index to that object's type
        :return: list of tuples of object indices; tuple ``k`` position ``s`` holds an
            object whose type equals ``types[s]``. The same object may appear more than
            once in a tuple; callers filter those out.
        """
        matches = [
            [idx for idx, kind in current.items() if kind == target]
            for target in types
        ]
        return list(itertools.product(*matches))

    @staticmethod
    def _flatten(parts):
        """Concatenate the per-object arrays in ``parts`` into one flat float vector."""
        arrays = [np.asarray(x, dtype=float) for x in parts]
        return np.concatenate(arrays) if arrays else np.zeros(0)

    def score_effect(self, i, effect, next_state, eff_mask):
        """Check whether the objects ``eff_mask`` of ``next_state`` look like an outcome of ``effect``.

        On first use for index ``i``, a one-class SVM is fitted to samples drawn
        from ``effect`` and cached. The entries of ``next_state`` selected by
        ``eff_mask`` are then flattened into one vector and classified.

        :param i: index of the effect, used as the cache key
        :param effect: the effect distribution; must provide ``sample`` and ``mask``
        :param next_state: the state after execution
        :param eff_mask: the object indices the effect is grounded on; ``None`` falls
            back to ``effect.mask``
        :return: 1 if the SVM predicts an inlier, otherwise 0
        """
        # An earlier variant scored with _rescale(effect.score(next_state, mask=eff_mask))
        # instead of a one-class SVM.
        if i not in self._ocs:
            X = samples2np(effect.sample(self._N_SAMPLES))
            self._ocs[i] = OneClassSVM(gamma='auto').fit(X)
        # Use the candidate grounding, not the operator's original mask; otherwise
        # every candidate tried by score() would be scored identically.
        selected = effect.mask if eff_mask is None else list(eff_mask)
        Y = self._flatten(next_state[selected])
        sign = self._ocs[i].predict(np.reshape(Y, (1, -1)))[0]
        return 1 if sign > 0 else 0

    def _best_effect_score(self, next_state, current_types, allowed):
        """Return the best effect score over all effect groundings drawn from ``allowed``.

        :param next_state: the state after execution
        :param current_types: mapping from object index to type for the current task
        :param allowed: set of object indices the effect mask must be a subset of
        :return: the maximum ``score_effect`` value, or 0 if no grounding qualifies
        """
        best = 0
        for idx, _prob in enumerate(self.operator.list_probabilities):
            effect = self.operator.list_effects[idx]
            effect_types = [self.classes[k] for k in effect.mask]
            for candidate in self._find_types(effect_types, current_types):
                eff_mask = list(candidate)
                distinct = set(eff_mask)
                # skip groundings that reuse an object or touch objects outside the precondition
                if len(distinct) != len(eff_mask) or not distinct.issubset(allowed):
                    continue
                best = max(best, self.score_effect(idx, effect, next_state, eff_mask))
        return best

    def score(self, state, option, next_state, current_types, object_idx):
        """Score how well the transition matches this operator under any type-consistent grounding.

        Every combination of objects whose types match the precondition mask is
        tried. A combination is only considered if it uses distinct objects and
        contains ``object_idx`` plus every object that differs between ``state``
        and ``next_state``.

        :param state: the state before execution, indexable by object
        :param option: the executed option
        :param next_state: the state after execution
        :param current_types: mapping from object index to type for the current task
        :param object_idx: index of an object that must be part of the grounding
        :return: ``(state_prob, best)`` where ``best`` is the highest
            ``precondition probability * effect score`` found and ``state_prob`` is the
            precondition probability of that grounding; ``(0, 0)`` if the option differs
            or nothing scores above zero
        """
        if option != self.option:
            return 0, 0

        required = {j for j in range(len(state)) if not np.array_equal(state[j], next_state[j])}
        required.add(object_idx)

        precondition = self.operator.precondition
        precondition_types = [self.classes[k] for k in precondition.mask]

        state_prob = 0
        best = 0
        # find all state variables such that they have the same type!
        for candidate in self._find_types(precondition_types, current_types):
            pre_mask = list(candidate)
            grounded = set(pre_mask)
            if len(grounded) != len(pre_mask) or not required.issubset(grounded):
                continue
            d = state[pre_mask]
            try:
                prob = precondition.probability(d, use_mask=False)
            except Exception:
                # a grounding the precondition cannot evaluate counts as probability 0
                prob = 0

            total_score = prob * self._best_effect_score(next_state, current_types, grounded)
            if total_score > best:
                best = total_score
                state_prob = prob

        return state_prob, best

    def proba(self, states, psymbol=None):
        """Average precondition probability over ``states``.

        :param states: the states to evaluate
        :param psymbol: if given, a symbol that must have been registered with ``link``
        :return: the mean precondition probability; 0 if ``psymbol`` is given but not
            linked, or if ``states`` is empty
        """
        if psymbol is not None and psymbol not in self.linking:
            return 0
        if len(states) == 0:
            return 0
        precondition = self.operator.precondition
        mask = precondition.mask
        total = sum(precondition.probability(state[mask], use_mask=False) for state in states)
        return total / len(states)

    def forward(self, states, psymbol=None):
        """Apply the operator's single effect to a copy of ``states``.

        For each state, the entries listed in the effect's mask are overwritten
        with one sample drawn from the effect.

        :param states: the states to advance
        :param psymbol: optional symbol the states belong to
        :return: ``(new_states, next_p)``. ``next_p`` is ``None`` when ``psymbol`` is
            ``None``; otherwise it is the symbol linked to ``psymbol``, or ``psymbol``
            itself when the link is ``None``
        :raises ValueError: if ``psymbol`` is given but has not been linked
        """
        if psymbol is None:
            next_p = None
        else:
            if psymbol not in self.linking:
                raise ValueError("no link registered for symbol {!r}".format(psymbol))
            next_p = self.linking[psymbol]
            if next_p is None:
                next_p = psymbol

        effect = self.operator.list_effects[0]
        mask = effect.mask

        # NOTE(review): np.copy is shallow for object arrays; if each state is itself an
        # array held in an object array, the writes below reach the caller's states - confirm.
        new = np.copy(states)
        if len(new) == 0:
            return new, next_p

        # Draw at least one sample per state; a fixed 100 ran out for larger batches.
        samples = effect.sample(max(len(new), self._N_SAMPLES))
        for i, row in enumerate(new):
            for j, m in enumerate(mask):
                row[m] = samples[i][j]
        return new, next_p
