import numpy as np





class State(np.ndarray):
    """
    A numpy array that also records the ID of the task it originated from.

    Ufunc results (arithmetic, comparisons, reductions, ...) are returned as
    ``State`` objects whose ``task_id`` is copied from the first ``State``
    operand, falling back to the first ``State`` passed via ``out=``.
    ``==`` and ``!=`` between two ``State`` objects compare the whole arrays
    and return a single ``bool``; the task ID is not part of that comparison.
    """

    def __new__(cls,
                state_array: np.ndarray,
                task_id: int):
        """
        Create a state vector that acts as a numpy array, but contains the task ID.
        :param state_array: the state vector (anything ``np.asarray`` accepts)
        :param task_id: the ID of the task from which the state originated
        """
        # np.asarray does not copy an existing ndarray and .view() never copies,
        # so the State shares memory with an ndarray input.
        obj = np.asarray(state_array).view(cls)
        obj._task_id = task_id
        return obj

    def __array_finalize__(self, obj):
        """
        Numpy hook run for every new State (construction, view casting, slicing).
        :param obj: the array the new State was derived from, or None when it was
            created directly through ``ndarray.__new__`` (e.g. while unpickling)
        """
        if obj is None:
            # Give the attributes a defined value so `task_id` never raises;
            # __new__ / __setstate__ overwrite it afterwards.
            self._info = None
            self._task_id = None
            return
        # `_info` is a leftover from numpy's subclassing example; nothing in this
        # class reads it.
        self._info = getattr(obj, 'info', None)
        # Propagate the task ID from the source array when it has one.
        self._task_id = getattr(obj, 'task_id', None)

    @property
    def task_id(self):
        """The ID of the task this state came from (may be None)."""
        return self._task_id

    def __str__(self):
        # Format "<array2string output>:<task_id>"; parsed back by from_string.
        return np.array2string(self, max_line_width=100000) + ':' + str(self.task_id)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """
        Run ``ufunc`` on plain-ndarray views of the operands and re-wrap the results.
        :return: State result(s) tagged with the task ID of the first State input,
            else of the first State ``out`` array, else None
        """
        # Unwrap State operands so the base implementation does not dispatch
        # back into this method.
        args = [x.view(np.ndarray) if isinstance(x, State) else x for x in inputs]

        source = next((x for x in inputs if isinstance(x, State)), None)

        outputs = kwargs.pop('out', None)
        if outputs:
            kwargs['out'] = tuple(o.view(np.ndarray) if isinstance(o, State) else o
                                  for o in outputs)
            if source is None:
                source = next((o for o in outputs if isinstance(o, State)), None)
        else:
            outputs = (None,) * ufunc.nout

        task_id = None if source is None else source.task_id

        results = super().__array_ufunc__(ufunc, method, *args, **kwargs)
        if results is NotImplemented:
            return NotImplemented

        if method == 'at':
            # ufunc.at works in place on inputs[0] and returns None.
            if isinstance(inputs[0], State):
                inputs[0]._task_id = task_id
            return None

        if ufunc.nout == 1:
            results = (results,)

        # Freshly allocated results become State views; caller-supplied `out`
        # arrays are returned as given.
        results = tuple(np.asarray(result).view(State) if output is None else output
                        for result, output in zip(results, outputs))
        for result in results:
            if isinstance(result, State):
                result._task_id = task_id

        return results[0] if len(results) == 1 else results

    def concatenate(self, other):
        """
        Join this state with ``other`` along the first axis.
        :return: a new State carrying this state's task ID
        """
        return State(np.concatenate((self, other)), self.task_id)

    def copy(self, order='K'):
        """
        Return an independent copy of the data with the same task ID.
        :param order: memory layout of the copy, as for ``np.copy``. Accepted so
            numpy helpers that call ``a.copy(order=...)`` (e.g. ``np.sort``) work.
        """
        return State(np.copy(self, order=order), self.task_id)

    def __eq__(self, other):
        """Whole-array equality with another State; task IDs are ignored."""
        if isinstance(other, State):
            return np.array_equal(self, other)
        return NotImplemented

    def __ne__(self, other):
        """
        Negation of ``__eq__`` for State operands, so ``!=`` also returns a bool.
        Non-State operands keep numpy's element-wise comparison.
        """
        if isinstance(other, State):
            return not np.array_equal(self, other)
        return super().__ne__(other)

    @staticmethod
    def from_string(str):
        """
        Parse the ``"<array>:<task_id>"`` text produced by ``__str__``.
        Numbers may be separated by whitespace and/or commas, optionally inside
        square brackets, e.g. ``"[0.5 1.5]:3"`` or ``"[1, 2, 3]:0"``.
        :param str: the text to parse (name kept for backward compatibility;
            it shadows the builtin inside this method)
        :return: a new State; a task ID of ``None`` is parsed back to None
        """
        parts = str.split(':')
        body = parts[0].strip()
        # Remove only the enclosing brackets, keeping every character inside them.
        if body.startswith('['):
            body = body[1:]
        if body.endswith(']'):
            body = body[:-1]
        values = np.fromstring(body.replace(',', ' '), sep=' ')
        task_text = parts[1].strip()
        task_id = None if task_text == 'None' else int(task_text)
        return State(values, task_id)

    @staticmethod
    def to_array(states: list):
        """
        Pack a sequence of states into a 1-D object array, one state per slot.
        :param states: a sized sequence of State objects
        :return: object ndarray of shape ``(len(states),)``
        """
        arr = np.empty((len(states),), dtype=object)
        # Assign element by element: a slice assignment makes numpy try to
        # broadcast equal-length states into an (n, k) block and fail.
        for i, state in enumerate(states):
            arr[i] = state
        return arr

    @staticmethod
    def to_np_array(states: np.ndarray):
        """
        Convert each state to a plain ndarray and build one array from them.
        :param states: iterable of State (or array-like) objects
        :return: a plain ndarray (2-D when all states have the same length)
        """
        return np.array([np.array(x) for x in states])

    def __reduce__(self):
        """Pickle support: append the task ID to ndarray's pickled state."""
        pickled_state = super().__reduce__()
        new_state = pickled_state[2] + (self.task_id,)
        return pickled_state[0], pickled_state[1], new_state

    def __setstate__(self, state):
        """Unpickle support: restore the task ID, then ndarray's own state."""
        self._task_id = state[-1]
        super().__setstate__(state[0:-1])
