from symbols.link.link import Link
from symbols.symbols.learned_operator import LearnedOperator
import numpy as np

def _rescale(logit):
    odds = np.exp(logit)
    return odds / (1 + odds)


class TransferableOperatorData:
    """Learned-operator data plus per-object type information.

    Like a learned operator, but it also records which objects are of which
    type, so the operator can be transferred. The wrapped ``operator_data``
    object supplies the option, partition, links, partitioned option and
    raw operator used below.
    """

    def __init__(self, operator_data, types):
        """Wrap ``operator_data`` together with its object types.

        Args:
            operator_data: the underlying learned-operator data object.
            types: per-object type information, stored as ``self.classes``
                (presumably indexed by object/state-variable index; confirm
                against callers).
        """
        self.operator_data = operator_data
        self.classes = types
        # Set by ``instantiate`` and cleared by ``reset``.
        self.mask = None

    def reset(self):
        """Clear the mask and rebuild ``operator_data.links`` from a fresh subpartitioning.

        If the partitioned option yields no subpartitions, a single
        precondition-only ``Link`` (with no next states) is built from the
        option's observations. Otherwise one ``Link`` is built per
        subpartition from its states and next states.
        """
        self.mask = None

        data = self.operator_data
        partitioned = data.partitioned_option
        data.subpartitions = partitioned.subpartition()

        if len(data.subpartitions) == 0:
            # No subpartitions, so only the precondition can be grounded.
            # NOTE(review): iterates over the number of rows in ``states`` but
            # indexes ``observations``; assumes the two are aligned.
            states = np.array([
                partitioned.extract_prob_space(partitioned.observations[i])
                for i in range(partitioned.states.shape[0])
            ])
            data.links = [Link(states, None)]
        else:
            data.links = [
                Link(sub.states, sub.next_states) for sub in data.subpartitions
            ]

    @property
    def links(self):
        """The ``Link`` objects held by the wrapped operator data."""
        return self.operator_data.links

    def add_problem_symbols(self, precondition_idx, effect_idx):
        """Link precondition symbol ``precondition_idx`` to effect symbol ``effect_idx``.

        The link is made on the wrapped raw operator. An ``effect_idx`` of
        ``-1`` means "no effect symbol" and is passed on as ``None``.
        """
        effect = None if effect_idx == -1 else effect_idx
        self.operator_data.raw_operator.link(precondition_idx, effect)

    def instantiate(self, mask):
        """Record ``mask`` on this object (it is cleared again by ``reset``)."""
        self.mask = mask

    @property
    def option(self):
        """The option of the wrapped operator data."""
        return self.operator_data.option

    @property
    def partition(self):
        """The partition of the wrapped operator data."""
        return self.operator_data.partition

    def score(self, state, option, next_state):
        """Score a ``(state, option, next_state)`` transition against this operator.

        Returns:
            ``(0, 0)`` if ``option`` differs from this operator's option.
            Otherwise ``(state_prob, total_score)``:

            - ``state_prob`` is the precondition's probability of ``state``
              restricted to the precondition mask.
            - ``total_score`` is ``state_prob`` times the largest rescaled
              effect score of ``next_state``, or 0 when there are no effects.
        """
        if option != self.option:
            return 0, 0

        operator = self.operator_data.raw_operator

        # TODO sort out masks: this uses the precondition's own mask, not the
        # ``self.mask`` recorded by ``instantiate``.
        mask = operator.precondition.mask
        state_prob = operator.precondition.probability(state[mask])

        # Take the best single effect match (not a probability-weighted sum).
        # ``list_effects`` is assumed parallel to ``list_probabilities``.
        max_eff = 0
        for _, effect in zip(operator.list_probabilities, operator.list_effects):
            max_eff = max(max_eff, _rescale(effect.score(next_state)))
        return state_prob, state_prob * max_eff

