import numpy as np

from train.package_manager import PackageManager

# Names of the main-hand item types. The position of a name in this list is the
# integer code stored for the equipped item when observations are prepared.
OBSERVATION_EQUIP = [
    "none",
    "air",
    "wooden_axe",
    "wooden_pickaxe",
    "stone_axe",
    "stone_pickaxe",
    "iron_axe",
    "iron_pickaxe",
    "other",
]


class AgentEpisode(object):
    """
    Collects the transitions of one episode in the order they are appended.

    Attributes:
        sequence: list of ``[state, action, reward, done, info]`` entries.
        sequences_length: number of entries appended so far.
        preprocessor: callable that receives this episode object in
            ``to_minerl_pkl_format``.
    """

    def __init__(self, preprocessor):
        """
        :param preprocessor: callable taking this episode as its only argument
        """
        self.preprocessor = preprocessor
        self.sequence = []
        self.sequences_length = 0

    def append(self, state, action, reward, done, info):
        """
        Stores one transition at the end of the episode and bumps the step counter.

        The five values are kept together as a single list entry of ``self.sequence``.
        """
        transition = [state, action, reward, done, info]
        self.sequence.append(transition)
        self.sequences_length += 1

    def to_minerl_pkl_format(self):
        """
        Hands the episode to the configured preprocessor.

        :returns whatever ``self.preprocessor`` returns for this episode
        """
        convert = self.preprocessor
        return convert(self)


class AgentState(object):
    """
    Maintains a sliding window of the most recent transitions in order to apply
    policy networks when interacting with environments.

    The window holds one transition more than the sequence length handed to the
    network, because observations are built from entries ``1..`` of the window.

    When ``PackageManager`` is enabled, the sequence length and the data transform
    are taken from its dataset instead of from the constructor arguments.
    """

    def __init__(self, sequence_length, data_transform):
        """
        :param sequence_length: number of time steps per prepared observation
            (used while PackageManager is not enabled)
        :param data_transform: callable applied to the tuple of prepared arrays
            (used while PackageManager is not enabled)
        """
        self.sequence_length = sequence_length
        self.data_transform = data_transform
        self.required_sequence_length = self.sequence_length + 1
        self.sequence = []

    def _sequence_lengths(self):
        """
        Resolves the lengths that are currently in effect.

        Both ``update`` and ``prepare_observation`` go through this helper so that
        the window size and the prepared array sizes can never disagree.

        :returns (sequence length, required window length)
        """
        manager = PackageManager.get_instance()
        if manager.enabled():
            seq_len = manager.dataset.SEQ_LENGTH
            return seq_len, seq_len + 1
        return self.sequence_length, self.required_sequence_length

    @staticmethod
    def _equip_index(item_type):
        """
        Maps a main-hand item type onto its position in OBSERVATION_EQUIP.

        Integer values are treated as already-encoded positions (the state example
        in ``prepare_observation`` shows ``'type': 0``) -- NOTE(review): confirm the
        environment's integer coding follows the order of OBSERVATION_EQUIP.
        Anything that is not a known name or a valid position falls into the
        "other" bucket instead of raising ValueError.
        """
        other = OBSERVATION_EQUIP.index("other")
        if isinstance(item_type, (int, np.integer)):
            if 0 <= item_type < len(OBSERVATION_EQUIP):
                return int(item_type)
            return other
        try:
            return OBSERVATION_EQUIP.index(item_type)
        except ValueError:
            return other

    def update(self, state, action, reward, done, info):
        """
        Appends a transition and drops the oldest ones so that at most the
        required window length is kept.
        """
        self.sequence.append([state, action, reward, done, info])

        _, req_seq_len = self._sequence_lengths()
        if len(self.sequence) > req_seq_len:
            self.sequence = self.sequence[-req_seq_len:]

    def prepare_observation(self):
        """
        Builds the network input from the current window and applies the data transform.

        Shapes below are those of the arrays before the data transform is applied
        (S = sequence length in effect); the transform's output is returned as is.

        :returns
            pov: S, C, H, W  (allocated as S, 3, 64, 64)
            binary_actions: S, 8
            camera_actions: S, 2
            enum_actions: S, 5
            inventory: S, 18
            equipped_items: S
            rewards: S, 1
        """
        seq_len, req_seq_len = self._sequence_lengths()
        assert len(self.sequence) == req_seq_len, \
            "expected %d stored transitions, got %d" % (req_seq_len, len(self.sequence))

        # state
        # {'equipped_items': {'mainhand': {'damage': 0, 'maxDamage': 0, 'type': 0}},
        # 'inventory': {'coal': 0, 'cobblestone': 0, 'crafting_table': 0, 'dirt': 0,
        #       'furnace': 0, 'iron_axe': 0, 'iron_ingot': 0, 'iron_ore': 0, 'iron_pickaxe': 0,
        #       'log': 0, 'planks': 0, 'stick': 0, 'stone': 0, 'stone_axe': 0,
        #       'stone_pickaxe': 0, 'torch': 0, 'wooden_axe': 0, 'wooden_pickaxe': 0},
        #  'pov': array([[[ 22,  44,  13], ...

        # prepare state input
        pov = np.zeros((seq_len, 3, 64, 64), dtype=np.float32)
        inventory = np.zeros((seq_len, 18), dtype=np.float32)
        equipped_items = np.zeros((seq_len,), dtype=np.int16)
        # the first stored transition is skipped: row k is filled from window entry k + 1
        for row, transition in enumerate(self.sequence[1:req_seq_len]):
            state = transition[0]
            # H, W, C -> C, H, W
            pov[row] = np.transpose(state['pov'], (2, 0, 1)).astype(np.float32)
            if 'inventory' in state:
                # relies on the iteration order of the inventory dict
                inventory[row] = list(state["inventory"].values())
            if 'equipped_items' in state:
                equipped_items[row] = self._equip_index(state["equipped_items"]["mainhand"]["type"])

        # TODO: implement this (keep in mind that actions and states are shifted by 1 step!!!)
        binary_actions = np.zeros((seq_len, 8), dtype=np.float32)
        camera_actions = np.zeros((seq_len, 2), dtype=np.float32)
        enum_actions = np.zeros((seq_len, 5), dtype=np.float32)
        rewards = np.zeros((seq_len, 1), dtype=np.float32)

        # apply data transforms
        data = pov, binary_actions, camera_actions, enum_actions, inventory, equipped_items, rewards
        manager = PackageManager.get_instance()
        data_transform = manager.dataset.DATA_TRANSFORM if manager.enabled() else self.data_transform
        pov, binary_actions, camera_actions, enum_actions, inventory, equipped_items, rewards = data_transform(data)

        return pov, binary_actions, camera_actions, enum_actions, inventory, equipped_items, rewards
