import numpy as np
import torch
from torch.distributions import Bernoulli, Categorical
from copy import deepcopy

# Notes:

# To design our action space, we performed a Wilcoxon statistical test on actions appearing together,
# under the assumption that their occurrences follow a Poisson distribution.
# For non-imitation-learning agents, some actions are performed one after another. For example, "nearbyCraft item" and
# "equip item" appear together significantly often; therefore, if we perform the action "nearbyCraft wooden_pickaxe",
# we also perform "equip wooden_pickaxe".


# Keys of the environment action dict that take 0/1 values. The list order
# defines the column order whenever binary actions are stacked into matrices.
BINARY_ACTIONS = [
    "attack", "back", "forward", "jump",
    "left", "right", "sneak", "sprint",
]

# Keys of the environment action dict that take one string out of a fixed set.
ENUM_ACTIONS = [
    "craft",
    "equip",
    "nearbyCraft",
    "nearbySmelt",
    "place",
]

# Allowed string values per enum action; a value's position in its list is the
# class index used by the categorical enum heads.
ENUM_ACTION_OPTIONS = {
    "equip": [
        "none", "air",
        "wooden_axe", "wooden_pickaxe",
        "stone_axe", "stone_pickaxe",
        "iron_axe", "iron_pickaxe",
    ],
    "place": ["none", "dirt", "stone", "cobblestone", "crafting_table", "furnace", "torch"],
    "craft": ["none", "torch", "stick", "planks", "crafting_table"],
    "nearbyCraft": [
        "none",
        "wooden_axe", "wooden_pickaxe",
        "stone_axe", "stone_pickaxe",
        "iron_axe", "iron_pickaxe",
        "furnace",
    ],
    "nearbySmelt": ["none", "iron_ingot", "coal"],
}

# Keys of the environment action dict that hold real-valued vectors.
CONTINUOUS_ACTIONS = ["camera"]


def stack_actions(env_actions, binary_actions):
    """
    Stack a list of action dictionaries into a binary matrix and a camera matrix.

        env_actions: list of T action dictionaries; each holds every key in
            ``binary_actions`` plus a ``"camera"`` entry with two values
        binary_actions: ordered list of binary action keys; defines the column order

    Returns ``(action_vec, camera_vec)``:
        action_vec: float32 array of shape (T, A) with A = len(binary_actions)
        camera_vec: float32 array of shape (T, 2)
    """
    n_actions = len(binary_actions)
    n_steps = len(env_actions)

    action_vec = np.zeros((n_steps, n_actions), dtype=np.float32)
    camera_vec = np.zeros((n_steps, 2), dtype=np.float32)
    for i, env_action in enumerate(env_actions):
        for j, bin_act in enumerate(binary_actions):
            action_vec[i, j] = env_action[bin_act]
        # Camera entries may be plain lists (not just arrays) or (2, 1)-shaped
        # arrays; asarray + reshape(-1) handles both, where .flatten() on a list fails.
        camera_vec[i] = np.asarray(env_action["camera"], dtype=np.float32).reshape(-1)
    return action_vec, camera_vec


def log_probs_entropy(logits, env_actions, action):
    """
    Compute log-probabilities and mean entropy of one enum action head.

        logits: unnormalized scores whose last dimension indexes
            ``ENUM_ACTION_OPTIONS[action]`` (presumably shape (T, n_options))
        env_actions: list of T action dictionaries, each holding an option string
            under the key ``action``
        action: enum action key, e.g. "equip" or "place"

    Returns ``(log_probs, entropy)``: ``log_probs`` is the per-step log-probability
    of the recorded option with a trailing singleton dimension added, and
    ``entropy`` is the mean entropy of the distribution.
    Raises ValueError if a recorded option is not in ``ENUM_ACTION_OPTIONS[action]``.
    """
    options = ENUM_ACTION_OPTIONS[action]
    indices = torch.tensor([options.index(env_action[action]) for env_action in env_actions],
                           dtype=torch.long, device=logits.device)
    # Parametrize with logits instead of softmax probabilities: when given probs,
    # torch clamps them into [eps, 1 - eps] before taking the log, which caps
    # log-probs near log(eps) and flattens gradients for confident heads.
    dist = Categorical(logits=logits)
    log_probs = dist.log_prob(indices).unsqueeze(-1)
    entropy = dist.entropy().mean()
    return log_probs, entropy


def stack_enum_action(env_actions):
    """
    Stack a list of action dictionaries into an enum-index matrix of shape (T, A).

        env_actions: list of T action dictionaries, each holding an option string
            for every key in ``ENUM_ACTIONS``

    Column ``j`` holds the position of ``env_action[ENUM_ACTIONS[j]]`` within
    ``ENUM_ACTION_OPTIONS[ENUM_ACTIONS[j]]`` (stored as float32), with
    A = len(ENUM_ACTIONS). Columns follow ``ENUM_ACTIONS`` order because code that
    reads this matrix by column (e.g. ``enumerate(ENUM_ACTIONS)``) assumes that order.
    Raises ValueError if a recorded option is unknown.
    """
    n_steps = len(env_actions)
    action_vec = np.zeros((n_steps, len(ENUM_ACTIONS)), dtype=np.float32)
    for i, env_action in enumerate(env_actions):
        for j, enum_act in enumerate(ENUM_ACTIONS):
            action_vec[i, j] = ENUM_ACTION_OPTIONS[enum_act].index(env_action[enum_act])
    return action_vec


def sample_with_temperature(logits, temperature=0.01):
    """
    Pick one index via the Gumbel-max trick with noise scaled by ``temperature``.

        logits: tensor of unnormalized scores; argmax is taken over the last dim
        temperature: multiplier on the Gumbel noise (smaller -> closer to argmax)

    Returns the chosen index as a Python number (``.item()``), so the perturbed
    argmax must contain exactly one element.
    """
    uniform = torch.empty_like(logits).uniform_()
    gumbel_noise = -torch.log(-torch.log(uniform))
    perturbed = logits + gumbel_noise * temperature
    return perturbed.argmax(dim=-1).item()


class ActionSpace:
    """
    Base interface for the action-space adapters in this module.

    Every hook here is a placeholder that returns ``None``; subclasses override
    the ones they support.
    """

    def __init__(self):
        """Nothing to set up in the base class."""

    def get_shape(self):
        """Return a dict of target name -> shape (placeholder: returns ``None``)."""

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray, enum_Actions, inventory,
                      equipped_items, reward: np.ndarray, steps_remaining: np.ndarray):
        """Build training targets from recorded inputs (placeholder: returns ``None``)."""

    def logits_to_dict(self, template, output_dict, **kwargs):
        """Turn model outputs into environment actions (placeholder: returns ``None``)."""


class MultiBinarySoftmaxCamera(ActionSpace):
    """
    Binary actions as independent Bernoullis; camera binned per axis with a softmax.

    Camera column 0 ("ud") and column 1 ("lr") are each discretized into
    ``len(bins) + 1`` classes via ``np.searchsorted`` on ``bins``.

        bins: sorted camera bin edges
    """

    def __init__(self, bins):
        super().__init__()
        self.bins = bins
        self.n_classes = len(self.bins) + 1
        # Representative value per class: the outermost edges for the two
        # open-ended classes, the midpoint of the enclosing edges for inner ones.
        self.bin_centers = [self.bins[0]] + \
                           [np.mean([self.bins[i:i + 2]]) for i in range(len(self.bins) - 1)] + \
                           [self.bins[-1]]
        self.bin_centers = np.array(self.bin_centers, dtype=np.float32)

    def get_shape(self):
        """Return target shapes: 8 binary actions and both camera axes' class vectors."""
        return {"binary_actions": (8,),
                "camera_actions": (2 * self.n_classes,)}

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray,
                      enum_actions: np.ndarray, inventory: np.ndarray,
                      equipped_items: np.ndarray, rewards: np.ndarray,
                      steps_remaining: np.ndarray):
        """
        Build training targets from the last step of a sequence.

            binary_actions: array indexed as (T, ...); only the last row is used
            camera_actions: array indexed as (T, 2); only the last row is used
            enum_actions, inventory, equipped_items, rewards, steps_remaining: unused

        Returns a dict with ``"binary_actions"`` (the last row) and
        ``"camera_actions"`` (soft bin targets of both axes, concatenated).
        """
        target_ud = self._to_categorical_target(camera_actions[-1, 0])
        target_lr = self._to_categorical_target(camera_actions[-1, 1])

        targets = dict()
        targets["binary_actions"] = binary_actions[-1]
        targets["camera_actions"] = np.concatenate((target_ud, target_lr), axis=0)

        return targets

    def logits_to_dict(self, template, output_dict, **kwargs):
        """Sample one environment action per batch row from ``"logits"`` and ``"camera"``."""
        action_logits = output_dict["logits"]
        camera_logits = output_dict["camera"]
        env_action = self._sample(template, action_logits, camera_logits, deterministic=False)
        return env_action

    def evaluate_actions(self, output_dict, env_actions):
        """
        Score recorded environment actions under the model's output distributions.

            output_dict: model outputs with ``"logits"`` (binary heads), ``"camera"``
                (ud and lr class logits concatenated on dim 1) and optionally ``"value"``
            env_actions: list of T recorded action dictionaries

        Returns ``(value, action_log_probs, camera_log_probs, action_entropy, camera_entropy)``.
        ``value`` is ``None`` when absent; the log-probs have one column per binary
        action / camera axis; the entropies are scalar means.
        """
        action_logits = output_dict["logits"]
        camera_logits = output_dict["camera"]

        # compute value output
        value = output_dict["value"] if "value" in output_dict else None

        # binary actions: independent Bernoullis. Parametrizing with logits avoids
        # the clamping torch applies to probabilities near 0 or 1.
        discrete_action_matrix, camera_actions = stack_actions(env_actions, BINARY_ACTIONS)
        discrete_action_matrix = torch.from_numpy(discrete_action_matrix).to(action_logits.device)

        action_dist = Bernoulli(logits=action_logits)
        action_log_probs = action_dist.log_prob(discrete_action_matrix)
        action_entropy = action_dist.entropy().mean()

        # camera: hard class index per axis, i.e. the bin each recorded value falls into
        # TODO: fix log_prob computation for soft targets!
        disc_cam_actions_ud = torch.from_numpy(np.searchsorted(self.bins, camera_actions[:, 0]))
        disc_cam_actions_lr = torch.from_numpy(np.searchsorted(self.bins, camera_actions[:, 1]))
        disc_cam_actions_ud = disc_cam_actions_ud.to(action_logits.device)
        disc_cam_actions_lr = disc_cam_actions_lr.to(action_logits.device)

        camera_dist_ud = Categorical(logits=camera_logits[:, :self.n_classes])
        camera_dist_lr = Categorical(logits=camera_logits[:, self.n_classes:])

        camera_log_probs_ud = camera_dist_ud.log_prob(disc_cam_actions_ud).unsqueeze(-1)
        camera_log_probs_lr = camera_dist_lr.log_prob(disc_cam_actions_lr).unsqueeze(-1)
        camera_log_probs = torch.cat((camera_log_probs_ud, camera_log_probs_lr), dim=1)

        camera_entropy = camera_dist_ud.entropy().mean() + camera_dist_lr.entropy().mean()

        return value, action_log_probs, camera_log_probs, action_entropy, camera_entropy

    def _sample(self, template, action_logits, camera_logits, deterministic):
        """
        Convert batched logits into a list of environment action dicts.

        Each dict is a deep copy of ``template`` with all ``BINARY_ACTIONS`` set to
        sampled 0/1 ints and ``"camera"`` set to a float32 array of (ud, lr) values.
        """
        # set action predictions
        action_probs = torch.sigmoid(action_logits)
        actions_dist = Bernoulli(probs=action_probs)
        actions = actions_dist.sample().detach().cpu().numpy()

        # set camera predictions
        action_probs_ud = torch.softmax(camera_logits[:, :self.n_classes], dim=1)
        action_probs_lr = torch.softmax(camera_logits[:, self.n_classes:], dim=1)
        proj_values_ud = self._soft_prediction_to_value(action_probs_ud, deterministic)
        proj_values_lr = self._soft_prediction_to_value(action_probs_lr, deterministic)

        # prepare agent actions
        env_actions = []
        for b in range(actions.shape[0]):
            env_action = deepcopy(template)
            for i, k in enumerate(BINARY_ACTIONS):
                env_action[k] = int(actions[b, i])
            env_action['camera'] = np.asarray([proj_values_ud[b], proj_values_lr[b]], np.float32)
            env_actions.append(env_action)

        return env_actions

    def _to_categorical_target(self, value, soft=True):
        """
        Encode one camera value as a target vector over ``n_classes`` bins.

        With ``soft=False`` (or when the value falls in an outer class) the target
        is one-hot at ``np.searchsorted(bins, value)``. Otherwise the mass is split
        between the two nearest bin centers, the closer one getting more weight.
        """
        # find covered bin
        idx = np.searchsorted(self.bins, value)
        target = np.full((self.n_classes,), fill_value=0.0, dtype=np.float32)

        # compile soft target vector
        if soft:
            if idx == 0 or idx == (self.n_classes - 1):
                target[idx] = 1
            else:
                dists = np.abs(value - self.bin_centers)
                sorted_idxs = np.argsort(dists)
                cum_dist = dists[sorted_idxs[0]] + dists[sorted_idxs[1]]
                target[sorted_idxs[0]] = dists[sorted_idxs[1]] / cum_dist
                target[sorted_idxs[1]] = dists[sorted_idxs[0]] / cum_dist
        else:
            target[idx] = 1

        return target

    def _soft_prediction_to_value(self, prediction, deterministic=False):
        """
        Map per-bin probabilities to one camera value per batch row.

            prediction: tensor of shape (B, n_classes) with per-bin probabilities
            deterministic: if True, average the bin centers of the mode and its
                direct neighbours, weighted by their renormalized probabilities
                (an expected-value alternative would be sum(centers * probs));
                otherwise sample a bin and return its center

        Returns a float32 numpy array of shape (B,) in both modes, so each row
        yields a scalar and the resulting camera action has shape (2,).
        """
        if not deterministic:
            indices = Categorical(prediction).sample().detach().cpu().numpy()
            return self.bin_centers[indices]

        prediction = prediction.detach().cpu().numpy()
        values = np.empty((prediction.shape[0],), dtype=np.float32)
        for b, mode_idx in enumerate(np.argmax(prediction, axis=1)):
            # mode bin plus its neighbours, clipped to the valid range [0, n_classes)
            lo = max(mode_idx - 1, 0)
            hi = min(mode_idx + 2, self.n_classes)
            weights = prediction[b, lo:hi]
            total = np.sum(weights)
            if total > 0:
                # renormalize so the result is a true weighted average of centers
                values[b] = np.sum(self.bin_centers[lo:hi] * weights) / total
            else:
                values[b] = self.bin_centers[mode_idx]
        return values


class MultiBinarySoftmaxCameraEnumActions(ActionSpace):
    """
    Binned-camera action space extended with categorical heads for enum actions.

    Binary actions are independent Bernoullis. Each camera axis (column 0 "ud",
    column 1 "lr") is a softmax over ``len(bins) + 1`` bins. Each key in
    ``ENUM_ACTIONS`` is a categorical over ``ENUM_ACTION_OPTIONS[key]``.

        bins: sorted camera bin edges
    """

    def __init__(self, bins):
        super().__init__()
        self.bins = bins
        self.n_classes = len(self.bins) + 1
        # Representative value per class: the outermost edges for the two
        # open-ended classes, the midpoint of the enclosing edges for inner ones.
        self.bin_centers = [self.bins[0]] + \
                           [np.mean([self.bins[i:i + 2]]) for i in range(len(self.bins) - 1)] + \
                           [self.bins[-1]]
        self.bin_centers = np.array(self.bin_centers, dtype=np.float32)

    def get_shape(self):
        """Return target shapes for binary actions and both camera axes' class vectors."""
        return {"binary_actions": (8,),
                "camera_actions": (2 * self.n_classes,)}

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray, enum_actions: np.ndarray,
                      inventory: np.ndarray, equipped_items: np.ndarray, rewards: np.ndarray,
                      steps_remaining: np.ndarray):
        """
        Build training targets from the last step of a sequence.

            binary_actions: array indexed as (T, ...); only the last row is used
            camera_actions: array indexed as (T, 2); only the last row is used
            enum_actions: array indexed as (T, len(ENUM_ACTIONS)) of option indices,
                whose columns are read in ``ENUM_ACTIONS`` order
            inventory, equipped_items, rewards, steps_remaining: unused

        Returns a dict with ``"binary_actions"``, ``"camera_actions"`` (soft bin
        targets of both axes, concatenated) and one integer ``"act_<key>"`` target
        per enum action.
        """
        # bin camera in both directions
        target_ud = self._to_categorical_target(camera_actions[-1, 0])
        target_lr = self._to_categorical_target(camera_actions[-1, 1])

        targets = dict()
        targets["binary_actions"] = binary_actions[-1]
        targets["camera_actions"] = np.concatenate((target_ud, target_lr), axis=0)

        # np.int was removed in NumPy 1.24; use an explicit integer width and cast
        # only the last row once instead of the whole matrix per enum action.
        last_enum = np.asarray(enum_actions)[-1].astype(np.int64)
        for i, enum_act in enumerate(ENUM_ACTIONS):
            targets["act_" + enum_act] = last_enum[i]

        return targets

    def logits_to_dict(self, template, output_dict, **kwargs):
        """Sample one environment action per batch row from all output heads."""
        action_logits = output_dict["logits"]
        camera_logits = output_dict["camera"]
        equip_logits = output_dict["act_equip"]
        craft_logits = output_dict["act_craft"]
        nearbyCraft_logits = output_dict["act_nearbyCraft"]
        nearbySmelt_logits = output_dict["act_nearbySmelt"]
        place_logits = output_dict["act_place"]
        env_action = self._sample(template, action_logits, camera_logits, equip_logits, place_logits, craft_logits,
                                  nearbyCraft_logits, nearbySmelt_logits, deterministic=False)
        return env_action

    def evaluate_actions(self, output_dict, env_actions):
        """
        Score recorded environment actions under the model's output distributions.

            output_dict: model outputs with ``"logits"``, ``"camera"``, one
                ``"act_<key>"`` entry per enum action and optionally ``"value"``
            env_actions: list of T recorded action dictionaries

        Returns, in order: value (``None`` if absent); binary, camera, equip, place,
        craft, nearbyCraft and nearbySmelt log-probs; then the matching scalar mean
        entropies in the same order.
        """
        action_logits = output_dict["logits"]
        camera_logits = output_dict["camera"]
        equip_logits = output_dict["act_equip"]
        craft_logits = output_dict["act_craft"]
        nearbyCraft_logits = output_dict["act_nearbyCraft"]
        nearbySmelt_logits = output_dict["act_nearbySmelt"]
        place_logits = output_dict["act_place"]

        # compute value output
        value = output_dict["value"] if "value" in output_dict else None

        # binary actions: independent Bernoullis. Parametrizing with logits avoids
        # the clamping torch applies to probabilities near 0 or 1.
        discrete_action_matrix, camera_actions = stack_actions(env_actions, BINARY_ACTIONS)
        discrete_action_matrix = torch.from_numpy(discrete_action_matrix).to(action_logits.device)

        action_dist = Bernoulli(logits=action_logits)
        action_log_probs = action_dist.log_prob(discrete_action_matrix)
        action_entropy = action_dist.entropy().mean()

        # enum action log probs and entropy
        equip_log_probs, equip_entropy = log_probs_entropy(equip_logits, env_actions, "equip")
        place_log_probs, place_entropy = log_probs_entropy(place_logits, env_actions, "place")
        craft_log_probs, craft_entropy = log_probs_entropy(craft_logits, env_actions, "craft")
        nearbyCraft_log_probs, nearbyCraft_entropy = log_probs_entropy(nearbyCraft_logits, env_actions, "nearbyCraft")
        nearbySmelt_log_probs, nearbySmelt_entropy = log_probs_entropy(nearbySmelt_logits, env_actions, "nearbySmelt")

        # camera: hard class index per axis, i.e. the bin each recorded value falls into
        disc_cam_actions_ud = torch.from_numpy(np.searchsorted(self.bins, camera_actions[:, 0]))
        disc_cam_actions_lr = torch.from_numpy(np.searchsorted(self.bins, camera_actions[:, 1]))
        disc_cam_actions_ud = disc_cam_actions_ud.to(action_logits.device)
        disc_cam_actions_lr = disc_cam_actions_lr.to(action_logits.device)

        camera_dist_ud = Categorical(logits=camera_logits[:, :self.n_classes])
        camera_dist_lr = Categorical(logits=camera_logits[:, self.n_classes:])

        camera_log_probs_ud = camera_dist_ud.log_prob(disc_cam_actions_ud).unsqueeze(-1)
        camera_log_probs_lr = camera_dist_lr.log_prob(disc_cam_actions_lr).unsqueeze(-1)
        camera_log_probs = torch.cat((camera_log_probs_ud, camera_log_probs_lr), dim=1)

        camera_entropy = camera_dist_ud.entropy().mean() + camera_dist_lr.entropy().mean()

        return value, action_log_probs, camera_log_probs, \
               equip_log_probs, place_log_probs, craft_log_probs, nearbyCraft_log_probs, nearbySmelt_log_probs, \
               action_entropy, camera_entropy, \
               equip_entropy, place_entropy, craft_entropy, nearbyCraft_entropy, nearbySmelt_entropy

    def _sample(self, template, action_logits, camera_logits, equip_logits, place_logits, craft_logits,
                nearbyCraft_logits, nearbySmelt_logits, deterministic):
        """
        Convert batched logits into a list of environment action dicts.

        Each dict is a deep copy of ``template`` with the binary keys set to sampled
        0/1 ints, every enum key set to its sampled option string and ``"camera"``
        set to a float32 array of (ud, lr) values.
        """
        # set action predictions
        action_probs = torch.sigmoid(action_logits)
        actions_dist = Bernoulli(probs=action_probs)
        actions = actions_dist.sample().detach().cpu().numpy()

        # enum action predictions
        equip_idx = self._sample_categorical(torch.softmax(equip_logits, dim=-1))
        place_idx = self._sample_categorical(torch.softmax(place_logits, dim=-1))
        craft_idx = self._sample_categorical(torch.softmax(craft_logits, dim=-1))
        nearbyCraft_idx = self._sample_categorical(torch.softmax(nearbyCraft_logits, dim=-1))
        nearbySmelt_idx = self._sample_categorical(torch.softmax(nearbySmelt_logits, dim=-1))

        # set camera predictions
        action_probs_ud = torch.softmax(camera_logits[:, :self.n_classes], dim=1)
        action_probs_lr = torch.softmax(camera_logits[:, self.n_classes:], dim=1)
        proj_values_ud = self._soft_prediction_to_value(action_probs_ud, deterministic)
        proj_values_lr = self._soft_prediction_to_value(action_probs_lr, deterministic)

        # prepare agent actions
        env_actions = []
        for b in range(actions.shape[0]):
            env_action = deepcopy(template)
            for i, k in enumerate(BINARY_ACTIONS):
                env_action[k] = int(actions[b, i])
            env_action['equip'] = ENUM_ACTION_OPTIONS['equip'][equip_idx[b]]
            env_action['place'] = ENUM_ACTION_OPTIONS['place'][place_idx[b]]
            env_action['craft'] = ENUM_ACTION_OPTIONS['craft'][craft_idx[b]]
            env_action['nearbyCraft'] = ENUM_ACTION_OPTIONS['nearbyCraft'][nearbyCraft_idx[b]]
            env_action['nearbySmelt'] = ENUM_ACTION_OPTIONS['nearbySmelt'][nearbySmelt_idx[b]]
            env_action['camera'] = np.asarray([proj_values_ud[b], proj_values_lr[b]], np.float32)
            env_actions.append(env_action)

        return env_actions

    def _to_categorical_target(self, value, soft=True):
        """
        Encode one camera value as a target vector over ``n_classes`` bins.

        With ``soft=False`` (or when the value falls in an outer class) the target
        is one-hot at ``np.searchsorted(bins, value)``. Otherwise the mass is split
        between the two nearest bin centers, the closer one getting more weight.
        """
        # find covered bin
        idx = np.searchsorted(self.bins, value)
        target = np.full((self.n_classes,), fill_value=0.0, dtype=np.float32)

        # compile soft target vector
        if soft:
            if idx == 0 or idx == (self.n_classes - 1):
                target[idx] = 1
            else:
                dists = np.abs(value - self.bin_centers)
                sorted_idxs = np.argsort(dists)
                cum_dist = dists[sorted_idxs[0]] + dists[sorted_idxs[1]]
                target[sorted_idxs[0]] = dists[sorted_idxs[1]] / cum_dist
                target[sorted_idxs[1]] = dists[sorted_idxs[0]] / cum_dist
        else:
            target[idx] = 1

        return target

    def _to_one_hot(self, value, enum_key):
        """Return a float32 one-hot vector of length ``len(ENUM_ACTION_OPTIONS[enum_key])``."""
        zeros = np.zeros((len(ENUM_ACTION_OPTIONS[enum_key]),), dtype=np.float32)
        zeros[value] = 1
        return zeros

    def _sample_categorical(self, probs):
        """Sample one index per row of ``probs`` and return them as a numpy array."""
        return Categorical(probs).sample().detach().cpu().numpy()

    def _soft_prediction_to_value(self, prediction, deterministic=False):
        """
        Map per-bin probabilities to one camera value per batch row.

            prediction: tensor of shape (B, n_classes) with per-bin probabilities
            deterministic: if True, return the expected value over the bin centers;
                otherwise sample a bin and return its center

        Returns a float32 numpy array of shape (B,) in both modes.
        """
        if deterministic:
            prediction = prediction.detach().cpu().numpy()
            # No keepdims: a (B, 1) result would make the camera action (2, 1)
            # instead of the (2,) produced by the sampling branch.
            return np.sum(self.bin_centers * prediction, axis=1)

        indices = Categorical(prediction).sample().detach().cpu().numpy()
        return self.bin_centers[indices]


class MultiBinarySoftmaxCameraEnumActionsValueFunction(MultiBinarySoftmaxCameraEnumActions):
    """
    Enum-action space that additionally emits a ``"value_function"`` target.

        bins: camera bin edges, forwarded to the parent
        value_before_end: steps with ``steps_remaining <= value_before_end`` get
            the target ``value_scale``; all other steps get 0
        value_scale: target value assigned to those steps
    """

    def __init__(self, bins, value_before_end=10, value_scale=1.0):
        super().__init__(bins)
        self.value_before_end = value_before_end
        self.value_scale = value_scale

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray, enum_actions: np.ndarray,
                      inventory: np.ndarray, equipped_items: np.ndarray, rewards: np.ndarray,
                      steps_remaining: np.ndarray):
        """Return the parent's targets plus a float32 ``"value_function"`` entry."""
        targets = super().prepare_input(binary_actions, camera_actions, enum_actions, inventory,
                                        equipped_items, rewards, steps_remaining)
        near_end = steps_remaining <= self.value_before_end
        targets["value_function"] = near_end.astype(np.float32) * self.value_scale
        return targets


class MultiBinaryWithCameraRegression(ActionSpace):
    """Binary actions as independent Bernoullis; camera regressed as two raw values."""

    def get_shape(self):
        """Return target shapes for binary actions, camera and the summed values."""
        return dict(binary_actions=(8,), camera_actions=(2,), values=(1,))

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray, values: np.ndarray,
                      inventory: np.ndarray, equiped_items: np.ndarray):
        """
        Prepare targets for model optimization from the last step of a sequence.

        :param binary_actions: per-step binary actions (8 per step); the last row is the target
        :param camera_actions: per-step camera values (2 per step); the last row is the target
        :param values: summed over axis 0, then given a leading singleton axis
        :param inventory: unused
        :param equiped_items: unused
        :return: dict with "binary_actions", "camera_actions" and "values"
        """
        return {
            "binary_actions": binary_actions[-1],
            "camera_actions": camera_actions[-1],
            "values": np.sum(values, axis=0)[np.newaxis, ...],
        }

    def logits_to_dict(self, template, output_dict, **kwargs):
        """
        Build one environment action per batch row.

        Binary actions are sampled row by row from sigmoid(logits); the camera is
        copied from the regression head as a two-element list.
        """
        action_logits = output_dict["logits"]
        camera_logits = output_dict["camera"]

        env_actions = []
        for row in range(action_logits.shape[0]):
            env_action = deepcopy(template)

            row_dist = Bernoulli(probs=torch.sigmoid(action_logits[row]))
            sampled = row_dist.sample().squeeze().detach().cpu().numpy()
            for col, key in enumerate(BINARY_ACTIONS):
                env_action[key] = int(sampled[col])

            camera = camera_logits[row].detach().squeeze().cpu().numpy()
            env_action['camera'] = [camera[0], camera[1]]
            env_actions.append(env_action)
        return env_actions


class MultiBinaryBinnedCamera(ActionSpace):
    """
    Binary actions as independent Bernoullis; camera encoded as one-hot bins per axis.

        bins: sorted camera bin edges; each axis gets ``len(bins) + 1`` classes
    """

    def __init__(self, bins):
        super().__init__()
        self.bins = bins
        self.n_classes = len(bins) + 1
        # outer classes use the outermost edges, inner classes the edge midpoints
        inner_centers = [np.mean([bins[i:i + 2]]) for i in range(len(bins) - 1)]
        self.bin_centers = np.asarray([bins[0], *inner_centers, bins[-1]], dtype=np.float32)

    def get_shape(self):
        """Return target shapes for binary actions and both camera axes' one-hot vectors."""
        return {"binary_actions": (8,),
                "camera_actions": (2 * self.n_classes,)}

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray, values: np.ndarray,
                      inventory: np.ndarray, equiped_items: np.ndarray):
        """
        Build targets from the last step: the binary row and one one-hot bin vector
        per camera axis (column 0 then column 1), concatenated.
        """
        one_hots = []
        for axis in (0, 1):
            one_hot = np.full((self.n_classes,), fill_value=0.0, dtype=np.float32)
            one_hot[np.searchsorted(self.bins, camera_actions[-1, axis])] = 1.0
            one_hots.append(one_hot)

        targets = dict()
        targets["binary_actions"] = binary_actions[-1]
        targets["camera_actions"] = np.concatenate(one_hots, axis=0)
        return targets

    def logits_to_dict(self, template, output_dict, **kwargs):
        """
        Write sampled binary actions and a projected camera value into ``template``.

        ``template`` is modified in place and also returned.
        """
        action_logits = output_dict["logits"]
        camera_logits = output_dict["camera"]

        sampled = Bernoulli(probs=torch.sigmoid(action_logits)).sample()
        sampled = sampled.squeeze().detach().cpu().numpy()
        for idx, key in enumerate(BINARY_ACTIONS):
            template[key] = int(sampled[idx])

        # TODO: watch out no sampling so far!
        # NOTE(review): sigmoid outputs are not normalized per axis, so this weighted
        # sum is only an expectation if the head's outputs already sum to 1 — confirm.
        cam_probs = torch.sigmoid(camera_logits).squeeze().detach().cpu().numpy()
        ud_probs = cam_probs[:self.n_classes]
        lr_probs = cam_probs[self.n_classes:]
        template['camera'] = np.asarray([np.sum(self.bin_centers * ud_probs),
                                         np.sum(self.bin_centers * lr_probs)], np.float32)
        return template


class MultiBinarySoftmaxCameraInventory(MultiBinarySoftmaxCamera):
    """Softmax-camera action space that also emits an ``"inventory"`` target."""

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray, values: np.ndarray,
                      inventory: np.ndarray, equiped_items: np.ndarray):
        """
        Return the parent's binary/camera targets plus an ``"inventory"`` target.

        The inventory target is ``inventory[:, -1] - inventory[:, -2]`` clipped to [0, 1].
        NOTE(review): this treats axis 1 of ``inventory`` as the time axis — confirm
        against the caller.
        """
        # The parent signature takes seven inputs (calling it with four raises
        # TypeError) but only reads binary_actions and camera_actions.
        targets = super().prepare_input(binary_actions, camera_actions, None, inventory,
                                        equiped_items, None, None)
        targets["inventory"] = np.maximum(np.minimum(inventory[:, -1] - inventory[:, -2], 1), 0)
        return targets


class MultiBinarySoftmaxCameraEquip(MultiBinarySoftmaxCamera):
    """Softmax-camera action space that also emits an ``"inventory"`` target."""

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray, values: np.ndarray,
                      inventory: np.ndarray, equiped_items: np.ndarray):
        """
        Return the parent's binary/camera targets plus an ``"inventory"`` target.

        The inventory target is ``inventory[:, -1] - inventory[:, -2]`` clipped to [0, 1].
        NOTE(review): despite the class name, ``equiped_items`` is not used for any
        target here — confirm whether an equipment target was intended.
        """
        # The parent signature takes seven inputs (calling it with four raises
        # TypeError) but only reads binary_actions and camera_actions.
        targets = super().prepare_input(binary_actions, camera_actions, None, inventory,
                                        equiped_items, None, None)
        targets["inventory"] = np.maximum(np.minimum(inventory[:, -1] - inventory[:, -2], 1), 0)
        return targets


class MultiStepMultiBinarySoftmaxCamera(MultiBinarySoftmaxCamera):
    """
    Softmax-camera action space predicting a window of ``2 * action_context + 1`` steps.

        bins: camera bin edges, forwarded to the parent
        action_context: number of steps on each side of the central step
    """

    def __init__(self, bins, action_context: int = 0):
        super().__init__(bins)
        self.action_context = action_context

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray, values: np.ndarray,
                      inventory: np.ndarray, equiped_items: np.ndarray):
        """
        Build targets for the last ``2 * action_context + 1`` steps.

        Returns ``"binary_actions"`` (those rows) and ``"camera_actions"`` of shape
        (window, 2 * n_classes): per-step soft bin targets, ud then lr.
        """
        window = 2 * self.action_context + 1
        targets_ud = np.zeros((window, self.n_classes), dtype=np.float32)
        targets_lr = np.zeros((window, self.n_classes), dtype=np.float32)
        # row r of the window corresponds to sequence index -window + r (oldest first)
        for row, step in enumerate(range(-window, 0)):
            targets_ud[row, :] = self._to_categorical_target(camera_actions[step, 0])
            targets_lr[row, :] = self._to_categorical_target(camera_actions[step, 1])

        return {
            "binary_actions": binary_actions[-window:],
            "camera_actions": np.concatenate((targets_ud, targets_lr), axis=1),
        }

    def logits_to_dict(self, template, output_dict, **kwargs):
        """Deterministically decode the central step of the predicted window."""
        center = -1 - self.action_context
        central_action_logits = output_dict["logits"][:, center, :]
        central_camera_logits = output_dict["camera"][:, center, :]
        return self._sample(template, central_action_logits, central_camera_logits, deterministic=True)


class SoftmaxActionNoCamera(ActionSpace):
    """
    Single categorical action space over 17 fixed action combinations.

    Camera movement is represented only by fixed-magnitude steps in four
    directions, folded into the categorical choices.

        temperature: Gumbel noise scale passed to ``sample_with_temperature``
        camera_magnitude: absolute camera step used for up/down/left/right
    """

    def __init__(self, temperature=0.01, camera_magnitude=10):
        super().__init__()
        self.temperature = temperature
        self.camera_magnitude = camera_magnitude
        self.action_mapping = {
            0: ["forward", "attack"],
            1: ["forward", "jump"],
            2: ["attack", "sneak"],
            3: ["attack"],
            4: ["attack", "camera_up"],
            5: ["attack", "camera_down"],
            6: ["attack", "camera_left"],
            7: ["attack", "camera_right"],
            8: ["forward"],
            9: ["forward", "camera_up"],
            10: ["forward", "camera_down"],
            11: ["forward", "camera_left"],
            12: ["forward", "camera_right"],
            13: ["camera_up"],
            14: ["camera_down"],
            15: ["camera_left"],
            16: ["camera_right"]
        }

    def get_shape(self):
        """Return the target shape: one entry per action combination."""
        return {"categorical_actions": (len(self.action_mapping),)}

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray, values: np.ndarray,
                      inventory: np.ndarray, equiped_items: np.ndarray):
        """
        Build a soft categorical target from the last recorded step.

        Every combination matching the last step's binary and camera actions gets
        weight 1, then the vector is normalized to sum to 1 (all-zero stays zero).
        Camera moves only match when exactly equal to +/- ``camera_magnitude``.
        """
        actions = np.zeros(self.get_shape()["categorical_actions"], dtype=np.float32)

        camera_none = np.logical_and(camera_actions[-1, 0] == 0, camera_actions[-1, 1] == 0)
        camera_up = camera_actions[-1, 0] == -self.camera_magnitude
        camera_down = camera_actions[-1, 0] == self.camera_magnitude
        camera_left = camera_actions[-1, 1] == -self.camera_magnitude
        camera_right = camera_actions[-1, 1] == self.camera_magnitude
        # column indices follow BINARY_ACTIONS: 0 attack, 2 forward, 3 jump, 6 sneak
        atk = binary_actions[-1, 0]
        fwd = binary_actions[-1, 2]
        # actions
        actions[0] = np.logical_and(fwd, atk)  # forward + attack
        actions[1] = np.logical_and(fwd, binary_actions[-1, 3])  # forward + jump
        actions[2] = np.logical_and(atk, binary_actions[-1, 6])  # attack + sneak
        actions[3] = np.logical_and(np.logical_and(atk, camera_none),
                                    np.logical_not(actions[0]))  # attack + no camera
        actions[4] = np.logical_and(atk, camera_up)  # attack + up
        actions[5] = np.logical_and(atk, camera_down)  # attack + down
        actions[6] = np.logical_and(atk, camera_left)  # attack + left
        actions[7] = np.logical_and(atk, camera_right)  # attack + right
        actions[8] = np.logical_and(np.logical_and(fwd, camera_none),
                                    np.logical_not(actions[0]))  # forward + no camera
        actions[9] = np.logical_and(fwd, camera_up)  # forward + up
        actions[10] = np.logical_and(fwd, camera_down)  # forward + down
        actions[11] = np.logical_and(fwd, camera_left)  # forward + left
        actions[12] = np.logical_and(fwd, camera_right)  # forward + right
        actions[13] = np.logical_and(camera_up, np.logical_not(np.logical_or(atk, fwd)))
        actions[14] = np.logical_and(camera_down, np.logical_not(np.logical_or(atk, fwd)))
        actions[15] = np.logical_and(camera_left, np.logical_not(np.logical_or(atk, fwd)))
        actions[16] = np.logical_and(camera_right, np.logical_not(np.logical_or(atk, fwd)))
        # normalize
        normalize = np.sum(actions)
        if normalize == 0:
            normalize = 1
        actions /= normalize

        return {"categorical_actions": actions}

    def logits_to_dict(self, template, output_dict, **kwargs):
        """
        Sample one combination from ``output_dict["logits"]`` and return it as an env action.

        The result is a deep copy of ``template`` with the chosen binary keys set to
        1 and, for camera combinations, ``"camera"`` set to the fixed step.
        ``template`` itself is left unchanged.
        """
        logits = output_dict["logits"]
        a = sample_with_temperature(logits, temperature=self.temperature)
        # Work on a copy: writing into the caller's template would let keys set in
        # one call persist into later actions whenever the template is reused.
        env_action = deepcopy(template)
        for action in self.action_mapping[a]:
            if "camera" in action:
                direction = action.split("_")[-1]
                if direction == "left":
                    env_action["camera"] = [0, -self.camera_magnitude]
                elif direction == "right":
                    env_action["camera"] = [0, self.camera_magnitude]
                elif direction == "up":
                    env_action["camera"] = [-self.camera_magnitude, 0]
                elif direction == "down":
                    env_action["camera"] = [self.camera_magnitude, 0]
            else:
                env_action[action] = 1
        return env_action


class SoftmaxActionsWithCameraRegression(ActionSpace):
    """Categorical action combinations plus a directly regressed 2-D camera action.

    The categorical head picks one entry of ``action_mapping``. The camera head's two
    outputs are written straight into ``template["camera"]`` (see ``logits_to_dict``).
    """

    def __init__(self, temperature=0.01):
        super().__init__()
        # temperature handed to ``sample_with_temperature`` when picking the categorical action
        self.temperature = temperature
        self.action_mapping = {
            0: ["attack"],
            1: ["forward"],
            2: ["forward", "attack"],
            3: ["forward", "jump"],
            4: ["attack", "sneak"]
        }

    def get_shape(self):
        """Return the per-sample target shapes of this action space."""
        # The camera is regressed, so its head has exactly two outputs (``logits_to_dict``
        # reads ``cam[0]`` and ``cam[1]``). This used to return ``2 * self.n_classes``,
        # but this class never defines ``n_classes``, so every call raised AttributeError.
        return {"categorical_actions": (len(self.action_mapping),),
                "camera_actions": (2,)}

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray, values: np.ndarray,
                      inventory: np.ndarray, equiped_items: np.ndarray):
        """Build training targets from the most recent step of a recorded action sequence.

        Args:
            binary_actions: binary action flags. The last axis is assumed to follow
                ``BINARY_ACTIONS`` order (index 0 attack, 1 back, 2 forward, 3 jump,
                6 sneak). Leading axes are flattened and the last row is used.
            camera_actions: passed through unchanged as the camera target.
            values, inventory, equiped_items: unused; kept for interface compatibility.

        Returns:
            dict with ``"categorical_actions"`` and ``"camera_actions"``.
            ``"categorical_actions"`` is a 1-D float32 vector over ``action_mapping``. It is
            normalized to sum to 1, or is all zeros when no listed combination is active.
        """
        # Collapse leading axes to (steps, flags). The previous ``squeeze()`` turned a
        # single-step (1, 8) input into a 1-D vector and broke the 2-D indexing.
        last_step = binary_actions.reshape(-1, binary_actions.shape[-1])[-1]
        atk = last_step[0]
        fwd = last_step[2]
        jmp = last_step[3]
        snk = last_step[6]

        actions = np.zeros(self.get_shape()["categorical_actions"], dtype=np.float32)
        actions[0] = atk
        actions[1] = fwd
        actions[2] = np.logical_and(fwd, atk)  # forward + attack
        actions[3] = np.logical_and(fwd, jmp)  # forward + jump
        actions[4] = np.logical_and(atk, snk)  # attack + sneak

        # Normalize over the active combinations. ``actions`` is 1-D, so the previous
        # ``actions.sum(axis=1)`` raised an AxisError.
        total = actions.sum()
        if total == 0:
            total = 1
        actions = actions / total

        # NOTE(review): camera_actions is returned as-is although get_shape declares (2,);
        # confirm whether callers already reduce it to the last step.
        return {"categorical_actions": actions, "camera_actions": camera_actions}

    def logits_to_dict(self, template, output_dict, **kwargs):
        """Sample a categorical action and copy the regressed camera output into ``template``.

        Args:
            template: mutable action dictionary; updated in place and returned.
            output_dict: model outputs; ``"logits"`` (categorical head) and ``"camera"``
                (regression head) are read.

        Returns:
            ``template`` with the sampled action keys set to 1 and ``"camera"`` set to
            the first two squeezed camera outputs (no sampling on the camera).
        """
        action_logits = output_dict["logits"]
        camera_logits = output_dict["camera"]

        a = sample_with_temperature(action_logits, temperature=self.temperature)
        for action in self.action_mapping[a]:
            template[action] = 1

        cam = camera_logits.detach().squeeze().cpu().numpy()
        template['camera'] = [cam[0], cam[1]]
        return template


class SoftmaxActionsSoftmaxCamera(ActionSpace):
    """Categorical action combinations plus a discretized (binned) camera action.

    Each camera component (index 0: up/down, index 1: left/right) is predicted as a
    softmax over ``n_classes = len(bins) + 1`` classes. Class ``k`` stands for the value
    ``bin_centers[k]``. The two outermost classes use the first/last bin edge; inner
    classes use the midpoint of two adjacent edges.
    """

    def __init__(self, bins, temperature=0.01):
        super().__init__()
        # temperature handed to ``sample_with_temperature`` when picking the categorical action
        self.temperature = temperature
        self.action_mapping = {
            0: ["attack"],
            1: ["forward"],
            2: ["forward", "attack"],
            3: ["forward", "jump"],
            4: ["attack", "sneak"],
            5: ["back"]
        }
        # camera bin edges; must be sorted ascending because they are used with np.searchsorted
        self.bins = bins
        self.n_classes = len(self.bins) + 1
        self.bin_centers = [self.bins[0]] + \
                           [0.5 * (self.bins[i] + self.bins[i + 1]) for i in range(len(self.bins) - 1)] + \
                           [self.bins[-1]]

    def get_shape(self):
        """Return the per-sample target shapes of this action space."""
        # prepare_input concatenates two soft targets of length n_classes, and logits_to_dict
        # splits the camera head at n_classes. The camera target therefore has
        # 2 * n_classes entries; the previous value (2,) did not match either.
        return {"categorical_actions": (len(self.action_mapping),),
                "camera_actions": (2 * self.n_classes,)}

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray, values: np.ndarray,
                      inventory: np.ndarray, equiped_items: np.ndarray):
        """Build training targets from the last step of a recorded action sequence.

        Args:
            binary_actions: 2-D binary flags indexed ``[step, flag]``. Flags are assumed to
                follow ``BINARY_ACTIONS`` order (0 attack, 1 back, 2 forward, 3 jump,
                6 sneak).
            camera_actions: 2-D camera values indexed ``[step, component]``. Component 0 is
                up/down and component 1 is left/right.
            values, inventory, equiped_items: unused; kept for interface compatibility.

        Returns:
            dict with ``"categorical_actions"`` and ``"camera_actions"``.
            ``"categorical_actions"`` is normalized to sum to 1, or is all zeros.
            ``"camera_actions"`` is the up/down and left/right soft targets concatenated
            (length ``2 * n_classes``).
        """
        actions = np.zeros((len(self.action_mapping),), dtype=np.float32)
        atk = binary_actions[-1, 0]
        bck = binary_actions[-1, 1]
        fwd = binary_actions[-1, 2]
        jmp = binary_actions[-1, 3]
        snk = binary_actions[-1, 6]

        actions[0] = atk
        actions[1] = fwd
        actions[2] = np.logical_and(fwd, atk)  # forward + attack
        actions[3] = np.logical_and(fwd, jmp)  # forward + jump
        actions[4] = np.logical_and(atk, snk)  # attack + sneak
        actions[5] = bck

        normalize = actions.sum()
        if normalize == 0:
            normalize = 1
        actions = actions / normalize

        # bin both camera components into soft class targets
        target_ud = self._to_soft_target(camera_actions[-1, 0])
        target_lr = self._to_soft_target(camera_actions[-1, 1])

        return {"categorical_actions": actions, "camera_actions": np.concatenate((target_ud, target_lr), axis=0)}

    def logits_to_dict(self, template, output_dict, **kwargs):
        """Sample a categorical action and decode the camera head into ``template``.

        Each camera component is decoded as the expected value over its class softmax.
        """
        action_logits = output_dict["logits"].squeeze()
        camera_logits = output_dict["camera"].squeeze()

        a = sample_with_temperature(action_logits, temperature=self.temperature)
        for action in self.action_mapping[a]:
            template[action] = 1

        # TODO: watch out no sampling so far! The camera uses the expectation instead.
        probs_ud = torch.softmax(camera_logits[:self.n_classes], dim=-1).detach().cpu().numpy()
        probs_lr = torch.softmax(camera_logits[self.n_classes:], dim=-1).detach().cpu().numpy()

        proj_value_ud = self._soft_prediction_to_value(probs_ud)
        proj_value_lr = self._soft_prediction_to_value(probs_lr)
        template['camera'] = np.asarray([proj_value_ud, proj_value_lr], np.float32)
        return template

    def _to_soft_target(self, value):
        """Encode a scalar camera value as a soft ("two-hot") target over the camera classes.

        The mass is split linearly between the two class centers that *bracket* ``value``.
        As a result, ``sum(bin_centers * target) == value`` (up to float32 precision)
        whenever ``value`` lies in ``[bin_centers[0], bin_centers[-1]]``. Values outside
        that range become a one-hot target on the nearest end class.

        The previous implementation used the two *nearest* centers. Those need not bracket
        the value (edge centers are only half a bin from their neighbour), so the target
        did not decode back to the encoded value.
        """
        centers = np.asarray(self.bin_centers, dtype=np.float64)
        target = np.zeros((self.n_classes,), dtype=np.float32)

        # index of the first center >= value; the bracketing pair is (upper - 1, upper)
        upper = int(np.searchsorted(centers, value))
        if upper == 0:
            target[0] = 1
        elif upper >= self.n_classes:
            target[-1] = 1
        else:
            lower = upper - 1
            span = centers[upper] - centers[lower]
            if span <= 0:
                # degenerate (duplicate) centers: put all mass on one of them
                target[upper] = 1
            else:
                w_upper = (value - centers[lower]) / span
                target[lower] = 1.0 - w_upper
                target[upper] = w_upper

        return target

    def _soft_prediction_to_value(self, prediction):
        """Return the expected camera value ``sum_k bin_centers[k] * prediction[k]``.

        For targets produced by ``_to_soft_target`` this recovers the encoded value
        (clipped to the center range).
        """
        return np.sum(np.asarray(self.bin_centers) * prediction)


class MultiBinaryIncludingCamera(ActionSpace):
    """Independent Bernoulli decisions over fixed action combinations, including camera turns.

    Each entry of ``action_mapping`` is an independent binary output. When sampling, the
    keys of every active combination are OR-ed into the action dict. ``camera_left`` and
    ``camera_right`` become a horizontal camera move of ``cam_angle``; they cancel out when
    both are active.
    """

    def __init__(self, cam_angle=10):
        super().__init__()
        self.action_mapping = [
            ["forward", "attack"],
            ["forward", "jump"],
            ["attack", "sneak"],
            ["attack"],
            ["attack", "camera_left"],
            ["attack", "camera_right"],
            ["forward"],
            ["forward", "camera_left"],
            ["forward", "camera_right"],
            ["camera_left"],
            ["camera_right"]
        ]
        # magnitude of a camera_left / camera_right move (written to camera index 1)
        self.cam_angle = cam_angle

    def get_shape(self):
        """Return the per-sample target shapes of this action space."""
        return {"binary_actions": (len(self.action_mapping),),
                "camera_actions": (0,)}

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray, values: np.ndarray,
                      inventory: np.ndarray, equiped_items: np.ndarray):
        """Return placeholder targets; this action space builds none (both entries are None)."""
        targets = dict()
        targets["binary_actions"] = None
        targets["camera_actions"] = None
        return targets

    def logits_to_dict(self, template, output_dict, **kwargs):
        """Sample from ``output_dict["logits"]``; returns a list with one action dict per batch row."""
        action_logits = output_dict["logits"]
        return self._sample(template, action_logits, None)

    def evaluate_actions(self, output_dict, env_actions):
        """Score recorded env actions under the current policy outputs.

        Args:
            output_dict: model outputs; ``"logits"`` is required and ``"value"`` is optional.
            env_actions: list of action dicts, one per row of ``output_dict["logits"]``.

        Returns:
            tuple ``(value, action_log_probs, camera_log_probs, action_entropy, camera_entropy)``.
            ``action_log_probs`` holds per-output Bernoulli log-probabilities. The camera
            entries are empty/zero because the camera is part of the binary outputs.
        """
        action_logits = output_dict["logits"]

        # compute value output
        value = output_dict["value"] if "value" in output_dict else None

        # binary targets, matched to the logits' device and dtype for the logits-based log_prob
        discrete_action_matrix = torch.from_numpy(self._stack_actions(env_actions)).to(
            device=action_logits.device, dtype=action_logits.dtype)

        # Parameterize by logits: with probs=sigmoid(logits), torch clamps probabilities away
        # from 0/1 before converting back to logits, which zeroes gradients of saturated outputs.
        action_dist = Bernoulli(logits=action_logits)
        action_log_probs = action_dist.log_prob(discrete_action_matrix)
        action_entropy = action_dist.entropy().mean()

        # dummy camera outputs (not used by this action space); a 0-d zero avoids the NaN
        # that ``mean()`` of an empty tensor would give for an empty batch
        camera_log_probs = torch.zeros((action_log_probs.shape[0], 0), device=action_log_probs.device)
        camera_entropy = torch.zeros((), device=action_log_probs.device)

        return value, action_log_probs, camera_log_probs, action_entropy, camera_entropy

    def _sample(self, template, action_logits, camera_logits):
        """Draw one Bernoulli sample per output and map each batch row to an env action dict.

        ``camera_logits`` is unused. Each returned dict is a deep copy of ``template``.
        """
        # logits=... samples from sigmoid(logits), i.e. the same distribution as before
        samples = Bernoulli(logits=action_logits).sample().detach().cpu().numpy()

        env_actions = []
        for row in samples:
            env_action = deepcopy(template)
            camera_left = 0
            camera_right = 0
            for i, action_list in enumerate(self.action_mapping):
                active = int(row[i])
                for action in action_list:
                    if action == "camera_left":
                        camera_left |= active
                    elif action == "camera_right":
                        camera_right |= active
                    else:
                        env_action[action] |= active

            # left and right cancel when both are active
            cam_lr = self.cam_angle * (camera_right - camera_left)
            env_action['camera'] = np.asarray([0, cam_lr], np.float32)
            env_actions.append(env_action)

        return env_actions

    def _stack_actions(self, env_actions):
        """
        Stack a list of action dictionaries into a binary matrix of shape (T, A).

        Entry ``[t, j]`` is 1 iff every key of ``action_mapping[j]`` is active in
        ``env_actions[t]``. Camera keys require camera index 1 to equal ``-cam_angle``
        (left) or ``+cam_angle`` (right).
        """
        n_actions = len(self.action_mapping)
        action_vec = np.zeros((len(env_actions), n_actions), dtype=np.float32)
        for i, env_action in enumerate(env_actions):
            for j, action_list in enumerate(self.action_mapping):
                all_true = True
                for action in action_list:
                    if action == "camera_left":
                        all_true &= env_action["camera"].flatten()[1] == -self.cam_angle
                    elif action == "camera_right":
                        all_true &= env_action["camera"].flatten()[1] == self.cam_angle
                    else:
                        all_true &= env_action[action]
                action_vec[i, j] = int(all_true)
        return action_vec


class SingleCategorical(ActionSpace):
    """One categorical choice per step over fixed action combinations, including camera turns.

    ``camera_left`` / ``camera_right`` map to a horizontal camera move of ``cam_angle``
    (camera index 1).
    """

    def __init__(self, cam_angle=10):
        super().__init__()
        self.action_mapping = [
            ["forward", "attack"],
            ["forward", "jump"],
            ["attack", "sneak"],
            ["attack"],
            ["attack", "camera_left"],
            ["attack", "camera_right"],
            ["forward"],
            ["forward", "camera_left"],
            ["forward", "camera_right"],
            ["camera_left"],
            ["camera_right"]
        ]
        # magnitude of a camera_left / camera_right move (written to camera index 1)
        self.cam_angle = cam_angle

    def get_shape(self):
        """Return the per-sample target shapes of this action space."""
        return {"binary_actions": (len(self.action_mapping),),
                "camera_actions": (0,)}

    def prepare_input(self, binary_actions: np.ndarray, camera_actions: np.ndarray, values: np.ndarray,
                      inventory: np.ndarray, equiped_items: np.ndarray):
        """Return placeholder targets; this action space builds none (both entries are None)."""
        targets = dict()
        targets["binary_actions"] = None
        targets["camera_actions"] = None
        return targets

    def logits_to_dict(self, template, output_dict, **kwargs):
        """Sample from ``output_dict["logits"]``; returns a list with one action dict per batch row."""
        action_logits = output_dict["logits"]
        return self._sample(template, action_logits, None)

    def evaluate_actions(self, output_dict, env_actions):
        """Score recorded env actions under the current policy outputs.

        Args:
            output_dict: model outputs. ``"logits"`` is required and is assumed 2-D
                (batch, len(action_mapping)); ``"value"`` is optional.
            env_actions: list of action dicts, one per batch row.

        Returns:
            tuple ``(value, action_log_probs, camera_log_probs, action_entropy, camera_entropy)``.
            ``action_log_probs`` has shape (batch, 1). The camera entries are empty/zero.

        Raises:
            ValueError: if an env action matches no entry of ``action_mapping``.
        """
        action_logits = output_dict["logits"]

        # compute value output
        value = output_dict["value"] if "value" in output_dict else None

        action_vector = torch.from_numpy(self._stack_actions(env_actions)).to(action_logits.device)

        # logits=... normalizes with a stable log-softmax over the last dim (== dim 1 for
        # 2-D logits) instead of taking log() of softmax probabilities
        action_dist = Categorical(logits=action_logits)
        action_log_probs = action_dist.log_prob(action_vector).unsqueeze(-1)
        action_entropy = action_dist.entropy().mean()

        # dummy camera outputs (not used by this action space)
        camera_log_probs = torch.zeros((action_log_probs.shape[0], 0), device=action_log_probs.device)
        camera_entropy = torch.tensor(0, dtype=torch.float32, device=action_logits.device)

        return value, action_log_probs, camera_log_probs, action_entropy, camera_entropy

    def _sample(self, template, action_logits, camera_logits):
        """Sample one combination per batch row and build the env action dicts.

        ``camera_logits`` is unused. Each returned dict is a deep copy of ``template``.
        """
        samples = Categorical(logits=action_logits).sample().detach().cpu().numpy()

        env_actions = []
        for choice in samples:
            env_action = deepcopy(template)
            for action in self.action_mapping[choice]:
                if action == "camera_left":
                    env_action["camera"][1] = -self.cam_angle
                elif action == "camera_right":
                    env_action["camera"][1] = self.cam_angle
                else:
                    env_action[action] = 1
            env_actions.append(env_action)

        return env_actions

    def _stack_actions(self, env_actions):
        """Map each action dictionary to the index of its first matching combination.

        An entry matches when every key of ``action_mapping[j]`` is active. Camera keys
        require camera index 1 to equal ``-cam_angle`` (left) or ``+cam_angle`` (right).
        Earlier entries win, so ``["forward", "attack"]`` takes precedence over ``["attack"]``.

        Returns:
            int64 array of shape (T,) with class indices.

        Raises:
            ValueError: if some action matches no entry. This replaces an ``assert``, which
                is stripped under ``python -O`` and would let -1 indices through.
        """
        action_vec = np.full((len(env_actions),), fill_value=-1, dtype=np.int64)

        for i, env_action in enumerate(env_actions):
            for j, action_list in enumerate(self.action_mapping):
                all_true = True
                for action in action_list:
                    if action == "camera_left":
                        all_true &= env_action["camera"].flatten()[1] == -self.cam_angle
                    elif action == "camera_right":
                        all_true &= env_action["camera"].flatten()[1] == self.cam_angle
                    else:
                        all_true &= env_action[action]

                if all_true:
                    action_vec[i] = j
                    break

        unmatched = np.flatnonzero(action_vec < 0)
        if unmatched.size > 0:
            raise ValueError(f"env actions at steps {unmatched.tolist()} match no entry of action_mapping")

        return action_vec
