import numpy as np




class TransitionSample:
    """
    A class holding a sample of a transition, which is a
    (state, option, reward, next_state) tuple, stored together with the
    matching observations and an object id.

    The ``view`` argument decides which constructor inputs are exposed through
    the ``state``/``next_state`` properties:

    * ``'problem'`` (default): ``state``/``next_state`` are the given states
      and ``observation``/``next_observation`` are the given observations.
    * ``'agent'``: the roles are swapped, i.e. ``state``/``next_state`` return
      the given observations and ``observation``/``next_observation`` return
      the given states.
    """

    def __init__(self,
                 state: np.ndarray,
                 observation: np.ndarray,
                 option: int,
                 object_id: int,
                 reward: float,
                 next_state: np.ndarray,
                 next_observation: np.ndarray,
                 view: str = 'problem'):
        """
        Create a new transition sample.

        :param state: the state before the option was executed
        :param observation: the observation before the option was executed
        :param option: the option that was executed
        :param object_id: the object id stored with the sample
        :param reward: the reward received
        :param next_state: the state after the option was executed
        :param next_observation: the observation after the option was executed
        :param view: either 'problem' or 'agent' (see the class docstring)
        :raises ValueError: if view is neither 'problem' nor 'agent'
        """
        self._option = option
        self._object_id = object_id
        self._reward = reward

        if view == 'problem':
            self._state = state
            self._next_state = next_state
            self._observation = observation
            self._next_observation = next_observation
        elif view == 'agent':
            self._state = observation
            self._next_state = next_observation
            self._observation = state
            self._next_observation = next_state
        else:
            # Fail here rather than leaving the state attributes unset, which
            # would only surface later as an AttributeError on first access.
            raise ValueError(
                "view must be 'problem' or 'agent', got {!r}".format(view))

        # Not set by the constructor; callers may assign it afterwards.
        self.episode = None

    @staticmethod
    def _flatten(parts) -> np.ndarray:
        """
        Concatenate the entries of parts and return the result as a 1-D array.
        """
        return np.concatenate(parts).ravel()

    @property
    def object_id(self):
        return self._object_id

    @property
    def state(self):
        return self._state

    @property
    def option(self):
        return self._option

    @property
    def reward(self):
        return self._reward

    @property
    def next_state(self):
        return self._next_state

    @property
    def observation(self):
        return self._observation

    @property
    def next_observation(self):
        return self._next_observation

    @property
    def flat_state(self):
        """
        The entries of the state concatenated into a single 1-D array.
        """
        return self._flatten(self.state)

    @property
    def flat_observation(self):
        """
        The entries of the observation concatenated into a single 1-D array.
        """
        return self._flatten(self.observation)

    @property
    def flat_next_state(self):
        """
        The entries of the next state concatenated into a single 1-D array.
        """
        return self._flatten(self.next_state)

    @property
    def flat_next_observation(self):
        """
        The entries of the next observation concatenated into a single 1-D array.
        """
        return self._flatten(self.next_observation)

    @property
    def mask(self):
        """
        The indices j for which state[j] differs from next_state[j].

        The result always has an integer dtype, even when nothing changed, so
        that it can be used directly as an index array.
        """
        following = self.next_state
        changed = [j for j, entry in enumerate(self.state)
                   if not np.array_equal(entry, following[j])]
        # Without an explicit dtype, an empty list would become a float64 array.
        return np.array(changed, dtype=np.intp)

    @property
    def flat_mask(self):
        """
        The indices at which the flattened state and the flattened next state
        differ. The result always has an integer dtype.
        """
        s = self._flatten(self.state)
        s_prime = self._flatten(self.next_state)
        return np.flatnonzero(s != s_prime)

