import numpy as np
from train.behavioral_cloning.datasets.agent_state import AgentEpisode


class MineRLObtainDiamondPklPreprocessor:
    """Turn an episode of per-step records into per-key stacked arrays.

    Each element of ``episode.sequence`` is indexed as
    ``(state, action, reward, done, ...)``; only the first four positions are
    read. The per-step values of every key are stacked along a new leading
    axis, so each returned array has ``len(episode.sequence)`` as its first
    dimension.
    """

    # Keys read from ``state["equipped_items"]["mainhand"]`` (cast to int).
    _MAINHAND_KEYS = ("damage", "maxDamage", "type")

    # Keys read from ``state["inventory"]`` (cast to int).
    _INVENTORY_KEYS = (
        "coal",
        "cobblestone",
        "crafting_table",
        "dirt",
        "furnace",
        "iron_axe",
        "iron_ingot",
        "iron_ore",
        "iron_pickaxe",
        "log",
        "planks",
        "stick",
        "stone",
        "stone_axe",
        "stone_pickaxe",
        "torch",
        "wooden_axe",
        "wooden_pickaxe",
    )

    # Keys read from the action dict, in the order they appear in the output.
    # "camera" is flattened per step and cast to float32; the rest are cast
    # to int.
    _ACTION_KEYS = (
        "attack",
        "back",
        "camera",
        "craft",
        "equip",
        "forward",
        "jump",
        "left",
        "nearbyCraft",
        "nearbySmelt",
        "place",
        "right",
        "sneak",
        "sprint",
    )

    def __init__(self, seed=None):
        """Store ``seed`` on the instance (not used by ``__call__``)."""
        self.seed = seed

    @staticmethod
    def _stack(values, dtype):
        """Stack per-step ``values`` along axis 0 and cast to ``dtype``."""
        return np.stack(values, axis=0).astype(dtype)

    def __call__(self, episode: "AgentEpisode"):
        """Convert ``episode`` into ``(states, actions, rewards, dones, infos)``.

        Args:
            episode: object exposing a non-empty ``sequence`` of step records.

        Returns:
            A 5-tuple:
              * ``states``: nested dict with ``equipped_items.mainhand.*`` and
                ``inventory.*`` as int arrays and ``pov`` as a uint8 array.
              * ``actions``: dict of int arrays, except ``camera`` (float32,
                each step flattened before stacking).
              * ``rewards``: float32 array.
              * ``dones``: bool array.
              * ``infos``: object array holding one empty dict per step.

        Raises:
            AssertionError: if ``episode.sequence`` is empty.
        """
        steps = episode.sequence
        seq_len = len(steps)
        assert seq_len > 0, "episode.sequence must contain at least one step"

        observations = [step[0] for step in steps]
        step_actions = [step[1] for step in steps]
        mainhands = [obs["equipped_items"]["mainhand"] for obs in observations]
        inventories = [obs["inventory"] for obs in observations]

        # NOTE: the builtin ``int``/``bool`` are used as dtypes because the
        # ``np.int``/``np.bool`` aliases (which pointed at these builtins)
        # were removed in NumPy 1.24.
        states = {
            "equipped_items": {
                "mainhand": {
                    key: self._stack([hand[key] for hand in mainhands], int)
                    for key in self._MAINHAND_KEYS
                }
            },
            "inventory": {
                key: self._stack([inv[key] for inv in inventories], int)
                for key in self._INVENTORY_KEYS
            },
            "pov": self._stack([obs["pov"] for obs in observations], np.uint8),
        }

        actions = {}
        for key in self._ACTION_KEYS:
            if key == "camera":
                # Camera is the only non-integer action; flatten each step's
                # value before stacking.
                actions[key] = self._stack(
                    [act[key].flatten() for act in step_actions], np.float32)
            else:
                actions[key] = self._stack(
                    [act[key] for act in step_actions], int)

        rewards = self._stack([step[2] for step in steps], np.float32)
        dones = self._stack([step[3] for step in steps], bool)
        # One fresh empty info dict per step, stacked into an object array.
        infos = np.stack([{} for _ in range(seq_len)], axis=0)

        return states, actions, rewards, dones, infos


class MineRLTreechopPklPreprocessor:
    """Turn an episode of per-step records into per-key stacked arrays.

    Each element of ``episode.sequence`` is indexed as
    ``(state, action, reward, done, ...)``; only the first four positions are
    read. Only the ``pov`` entry of the state is used. The per-step values of
    every key are stacked along a new leading axis, so each returned array has
    ``len(episode.sequence)`` as its first dimension.
    """

    # Keys read from the action dict, in the order they appear in the output.
    # "camera" is flattened per step and cast to float32; the rest are cast
    # to int.
    _ACTION_KEYS = (
        "attack",
        "back",
        "camera",
        "forward",
        "jump",
        "left",
        "right",
        "sneak",
        "sprint",
    )

    def __init__(self, seed=None):
        """Store ``seed`` on the instance (not used by ``__call__``)."""
        self.seed = seed

    @staticmethod
    def _stack(values, dtype):
        """Stack per-step ``values`` along axis 0 and cast to ``dtype``."""
        return np.stack(values, axis=0).astype(dtype)

    def __call__(self, episode: "AgentEpisode"):
        """Convert ``episode`` into ``(states, actions, rewards, dones, infos)``.

        Args:
            episode: object exposing a non-empty ``sequence`` of step records.

        Returns:
            A 5-tuple:
              * ``states``: dict with a single ``pov`` uint8 array.
              * ``actions``: dict of int arrays, except ``camera`` (float32,
                each step flattened before stacking).
              * ``rewards``: float32 array.
              * ``dones``: bool array.
              * ``infos``: object array holding one empty dict per step.

        Raises:
            AssertionError: if ``episode.sequence`` is empty.
        """
        steps = episode.sequence
        seq_len = len(steps)
        assert seq_len > 0, "episode.sequence must contain at least one step"

        step_actions = [step[1] for step in steps]

        states = {
            "pov": self._stack([step[0]["pov"] for step in steps], np.uint8),
        }

        # NOTE: the builtin ``int``/``bool`` are used as dtypes because the
        # ``np.int``/``np.bool`` aliases (which pointed at these builtins)
        # were removed in NumPy 1.24.
        actions = {}
        for key in self._ACTION_KEYS:
            if key == "camera":
                # Camera is the only non-integer action; flatten each step's
                # value before stacking.
                actions[key] = self._stack(
                    [act[key].flatten() for act in step_actions], np.float32)
            else:
                actions[key] = self._stack(
                    [act[key] for act in step_actions], int)

        rewards = self._stack([step[2] for step in steps], np.float32)
        dones = self._stack([step[3] for step in steps], bool)
        # One fresh empty info dict per step, stacked into an object array.
        infos = np.stack([{} for _ in range(seq_len)], axis=0)

        return states, actions, rewards, dones, infos