import numpy as np
from train.behavioral_cloning.spaces.action_spaces import ENUM_ACTIONS, ENUM_ACTION_OPTIONS


class InputSpace:
    """
        Base interface for input spaces that turn raw per-step arrays into a network input dictionary.
        The base implementation of `prepare` is a no-op that returns None; concrete input spaces override it.
        """

    def __init__(self):
        """The base input space holds no state."""

    def prepare(self, pov: np.ndarray, binary_actions: np.ndarray, camera_actions: np.ndarray,
                enum_actions: np.ndarray, inventory: np.ndarray, equipped_items: np.ndarray,
                rewards: np.ndarray, steps_remaining: np.ndarray):
        """
            Build the network input dictionary from per-step arrays.
            Base implementation: does nothing and returns None.
            """
        return None


class SingleFrameWithBinaryActionAndContinuousCameraSequence(InputSpace):
    """
        Input space that keeps camera actions continuous and one-hot encodes enum actions.
        Every per-step input is truncated to its first `sequence_len` steps.
        """

    def __init__(self, sequence_len: int):
        super().__init__()
        # number of leading steps kept from every per-step input
        self.sequence_len = sequence_len

    def prepare(self, pov: np.ndarray, binary_actions: np.ndarray, camera_actions: np.ndarray,
                enum_actions: np.ndarray, inventory: np.ndarray, equipped_items: np.ndarray,
                rewards: np.ndarray, steps_remaining: np.ndarray):
        """
            Prepare inputs for processing with networks.
            Each input is truncated to the first `sequence_len` steps; an AssertionError is raised if any
            provided input has fewer than `sequence_len` steps.
            :param pov: S, C, H, W
            :param binary_actions: S, 8
            :param camera_actions: S, 2
            :param enum_actions: S, 5 (optional, may be None)
            :param inventory: S, 18 (optional, may be None)
            :param equipped_items: S, 1 (optional, may be None)
            :param rewards: S, 1
            :param steps_remaining: not used by this input space
            :return: input-dictionary (has to match the selected model and training strategy) with keys
                "pov", "binary_actions", "camera_actions", "rewards" and, only when the corresponding
                argument is not None, "act_<name>" for each name in ENUM_ACTIONS, "inventory" and
                "equipped_items"
            """
        inputs = dict()
        inputs["pov"] = pov[:self.sequence_len]
        inputs["binary_actions"] = binary_actions[:self.sequence_len]
        inputs["camera_actions"] = camera_actions[:self.sequence_len]
        inputs["rewards"] = rewards[:self.sequence_len]

        if enum_actions is not None:
            # column i of enum_actions holds the option index of the i-th enum action
            for i, enum_act in enumerate(ENUM_ACTIONS):
                inputs["act_" + enum_act] = self._enum_to_one_hot(enum_actions[:self.sequence_len, i], enum_act)
                assert inputs["act_" + enum_act].shape[0] == self.sequence_len
        if inventory is not None:
            inputs["inventory"] = inventory[:self.sequence_len]
            assert inputs["inventory"].shape[0] == self.sequence_len
        if equipped_items is not None:
            inputs["equipped_items"] = equipped_items[:self.sequence_len]
            assert inputs["equipped_items"].shape[0] == self.sequence_len

        assert inputs["pov"].shape[0] == self.sequence_len
        assert inputs["binary_actions"].shape[0] == self.sequence_len
        assert inputs["camera_actions"].shape[0] == self.sequence_len
        assert inputs["rewards"].shape[0] == self.sequence_len

        return inputs

    def _enum_to_one_hot(self, values, enum_key):
        """
            One-hot encode enum option indices.
            :param values: length-S array of option indices (cast to int, truncating any fractional part)
            :param enum_key: key into ENUM_ACTION_OPTIONS selecting the number of options
            :return: S, len(ENUM_ACTION_OPTIONS[enum_key]) float32 array with a single 1 per row
            """
        # builtin `int` replaces the `np.int` alias, which was removed in NumPy 1.24
        indices = np.asarray(values).astype(int)
        one_hot = np.zeros((len(indices), len(ENUM_ACTION_OPTIONS[enum_key])), dtype=np.float32)
        # pairwise (row, column) indexing: row i gets a 1 only in column indices[i]
        one_hot[np.arange(len(indices)), indices] = 1
        return one_hot


class SingleFrameWithBinaryActionAndBinnedCameraSequence(InputSpace):
    """
        Input space that discretises continuous camera actions into bins and one-hot encodes them.
        Every per-step input is truncated to its first `sequence_len` steps.
        """

    def __init__(self, sequence_len: int, bins):
        super().__init__()
        # number of leading steps kept from every per-step input
        self.sequence_len = sequence_len
        # bin edges passed to np.searchsorted; assumed 1-D and sorted ascending - TODO confirm with caller
        self.bins = bins
        # np.searchsorted yields indices in [0, len(bins)], i.e. len(bins) + 1 classes
        self.n_classes = len(self.bins) + 1

    def prepare(self, pov: np.ndarray, binary_actions: np.ndarray, camera_actions: np.ndarray,
                enum_actions: np.ndarray, inventory: np.ndarray, equipped_items: np.ndarray, rewards: np.ndarray,
                steps_remaining: np.ndarray = None):
        """
            Prepare inputs for processing with networks.
            Each input is truncated to the first `sequence_len` steps; an AssertionError is raised if any
            provided input has fewer than `sequence_len` steps.
            :param pov: S, C, H, W
            :param binary_actions: S, 8
            :param camera_actions: S, 2 (continuous; binned and one-hot encoded by `prepare_input`)
            :param enum_actions: S, 5 (optional, may be None)
            :param inventory: S, 18 (optional, may be None)
            :param equipped_items: S, 1 (optional, may be None)
            :param rewards: S, 1
            :param steps_remaining: not used; accepted so the signature matches InputSpace.prepare
            :return: input-dictionary (has to match the selected model and training strategy) with keys
                "pov", "binary_actions", "camera_actions", "rewards" and, only when the corresponding
                argument is not None, "enum_actions", "inventory" and "equipped_items"
            """
        inputs = dict()
        inputs["pov"] = pov[:self.sequence_len]
        inputs["binary_actions"] = binary_actions[:self.sequence_len]
        inputs["camera_actions"] = self.prepare_input(camera_actions[:self.sequence_len])
        assert inputs["pov"].shape[0] == self.sequence_len
        assert inputs["binary_actions"].shape[0] == self.sequence_len
        assert inputs["camera_actions"].shape[0] == self.sequence_len

        # optional inputs are skipped when None, matching the continuous-camera input space
        if enum_actions is not None:
            inputs["enum_actions"] = enum_actions[:self.sequence_len]
            assert inputs["enum_actions"].shape[0] == self.sequence_len
        if inventory is not None:
            inputs["inventory"] = inventory[:self.sequence_len]
            assert inputs["inventory"].shape[0] == self.sequence_len
        if equipped_items is not None:
            inputs["equipped_items"] = equipped_items[:self.sequence_len]
            assert inputs["equipped_items"].shape[0] == self.sequence_len

        inputs["rewards"] = rewards[:self.sequence_len]
        assert inputs["rewards"].shape[0] == self.sequence_len

        return inputs

    def prepare_input(self, camera_actions: np.ndarray):
        """
            Bin continuous camera actions and one-hot encode each direction.
            :param camera_actions: S, 2 (column 0 and column 1 are binned independently; presumably
                up/down and left/right per the original naming - TODO confirm)
            :return: S, 2 * n_classes float32 array: one-hot bins of column 0 followed by those of column 1
            """
        # bin camera in both directions
        cam_ud = self._bin_one_hot(camera_actions[:, 0])
        cam_lr = self._bin_one_hot(camera_actions[:, 1])
        return np.concatenate((cam_ud, cam_lr), axis=1)

    def _bin_one_hot(self, values: np.ndarray):
        """One-hot encode the searchsorted bin index of each value (length S) into an S, n_classes float32 array."""
        one_hot = np.zeros((values.shape[0], self.n_classes), dtype=np.float32)
        bin_idx = np.searchsorted(self.bins, values)
        # pairwise (row, column) indexing: `one_hot[:, bin_idx]` would instead mark the bins of
        # every step in every row
        one_hot[np.arange(values.shape[0]), bin_idx] = 1.0
        return one_hot
