#!/usr/bin/env python3
import fcntl
import getpass
import logging
import os
import queue
import subprocess
import tempfile
import threading
from copy import deepcopy
import numpy as np
import coloredlogs
import gym
import random

from train.behavioral_cloning.datasets.preprocessor import MineRLObtainDiamondPklPreprocessor, \
    MineRLTreechopPklPreprocessor
from train.experience_replay.experience_wrappers import PartialExperienceSamplingWrapper
from train.package_manager import PackageManager

coloredlogs.install(logging.INFO)

from train.baselines.base_vec_env import VecEnv
from train.baselines.subproc_vec_env import SubprocVecEnv, _worker
from train.behavioral_cloning.datasets.agent_state import AgentState, AgentEpisode
from train.behavioral_cloning.datasets.experience import Experience
from gym.wrappers import TimeLimit
from train.baselines.monitor import Monitor
from train.enums import ExecMode


def build_env(env_id, seq_len, transforms, input_space, use_reprio=False, env_server=False, host='127.0.0.1', port=9999,
              thread_id=0, seed=None, experience_recording=None, replay_until=None, experience_folder=None, checkpoint=None,
              frame_skip=None, make_new_recording=False, exec_mode: ExecMode = ExecMode.Train,
              craft_equip=True, monitor_equip=True, generate_experience_video=False, debug_video='env'):
    """Create one MineRL environment and wrap it in this module's wrapper stack.

    Wrapping order (innermost first): SafeActionsWrapper, SaveStepWrapper, then exactly one of
    ExperiencePlaybackWrapper / ExperienceSamplingWrapper / PartialExperienceSamplingWrapper / InitialSeedWrapper,
    optionally ExperienceRecordingWrapper, FrameSkip, CraftEquipWrapper, MonitorEquipWrapper, then
    process reprioritization, MetaControllerWrapper, AgentStateWrapper, Monitor and finally an optional
    wrapper that exposes the recorded / replayed experience.

    :param env_id: gym id of the environment to create.
    :param env_server: if True, create the env through a remote env server at ``host:port``.
    :param thread_id: index of the worker; selects the lock file used by ``gym_sync_create``.
    :param seed: initial seed when not replaying an experience; ``None`` falls back to 42 with a warning.
    :param experience_recording: a recorded experience to replay up to ``checkpoint`` on every reset.
    :param replay_until: passed to PartialExperienceSamplingWrapper; the string ``'old_version'`` selects the
        legacy ExperienceSamplingWrapper instead.
    :param experience_folder: folder of recorded experiences to sample from on reset.
    :param checkpoint: checkpoint key of ``experience_recording`` (or ``'latest'``); required with a recording.
    :param frame_skip: if not None, repeat each action for this many frames.
    :param make_new_recording: wrap in ExperienceRecordingWrapper (extending the replayed experience if any).
    :param exec_mode: accepted for interface compatibility; not used in this function.
    :return: the fully wrapped environment.
    """
    logging.info(f'Creating environment: {env_id}')
    if env_server:
        env = gym_remote_sync_create(env_id, host, port)
    else:
        env = gym_sync_create(env_id, thread_id)

    env = SafeActionsWrapper(env, env_id)

    env = SaveStepWrapper(env)

    if experience_recording:
        if seed is not None:
            logging.info(
                'Using experience playback. The seed given ({}) will be ignored. (The recording only works on the '
                'seed it was recorded on, which is {})'.format(seed, experience_recording.meta_info))
        assert checkpoint is not None, "If using an experience_recording, you must also set a checkpoint!"
        valid_checkpoints = list(experience_recording.checkpoints.keys())
        valid_checkpoints.append('latest')
        assert checkpoint in valid_checkpoints, "{} is not a checkpoint present in the experience recording!".format(
            checkpoint)

        env = ExperiencePlaybackWrapper(
            env=env,
            experience_recording=experience_recording,
            checkpoint=checkpoint
        )

    elif experience_folder:
        # compare by value: `is` on strings depends on interning and silently fails for strings built at runtime
        if replay_until == 'old_version':
            # read in the experiences in the folder given and playback a randomly chosen experience from the folder
            env = ExperienceSamplingWrapper(
                env=env,
                experience_folder=experience_folder
            )
        else:
            env = PartialExperienceSamplingWrapper(
                env=env,
                experience_folder=experience_folder,
                replay_until=replay_until,
                verbose=False,
                generate_experience_video=generate_experience_video,
                debug_video=debug_video
            )

    else:
        # only a missing seed falls back to the default; 0 is a valid seed
        if seed is None:
            logging.warning(
                "Warning! You must set an initial seed when not replaying a recorded experience. Using initial_seed=42.")
            seed = 42
        else:
            logging.info(f"Recording new experience from env with seed {seed}")
        env = InitialSeedWrapper(env=env, initial_seed=seed)

    if make_new_recording:
        save_path = 'tmp/experiences/bc_agent/tests_for_master'
        if experience_recording or experience_folder:
            # extend the experience that is being replayed instead of starting a fresh one
            env = ExperienceRecordingWrapper(env=env,
                                             save_path=save_path,
                                             extend_experience=True, experience_recording=experience_recording)
        else:
            env = ExperienceRecordingWrapper(env=env,
                                             save_path=save_path)

    if frame_skip is not None:
        env = FrameSkip(env, skip=frame_skip)

    if craft_equip:
        env = CraftEquipWrapper(env)

    if monitor_equip:
        env = MonitorEquipWrapper(env)

    env = reprioritze_env(env, use_reprio)

    env = MetaControllerWrapper(env)

    env = AgentStateWrapper(env, seq_len, transforms, input_space)

    env = Monitor(env, None, allow_early_resets=True)

    if make_new_recording:
        env = ExposeRecordedExperienceWrapper(env=env)
    elif experience_recording:
        env = ExposeExperiencePlaybackWrapper(env=env)

    return env


def make_minerl(env_id, n_cpu, seq_len, transforms, input_space, use_reprio=False, env_server=False, host='127.0.0.1',
                port=9999, seed=None, experience_recording=None, replay_until=None, checkpoint=None, frame_skip=None,
                make_new_recording=False, experience_folder=None, exec_mode: ExecMode = ExecMode.Train,
                craft_equip=True, monitor_equip=True, generate_experience_video=False, debug_video='env'):
    """Create ``n_cpu`` environments via ``build_env`` and run them in a ThreadedVecEnv.

    Each env factory gets its own ``thread_id`` (0 .. n_cpu-1); all other arguments are forwarded unchanged to
    ``build_env``.
    """
    return ThreadedVecEnv(
        # `i=i` binds the current loop value; a plain closure would see only the final i in every factory
        [lambda i=i: build_env(env_id, seq_len, transforms, input_space, use_reprio=use_reprio, env_server=env_server,
                               host=host, port=port, thread_id=i, seed=seed, experience_recording=experience_recording,
                               replay_until=replay_until,
                               checkpoint=checkpoint, frame_skip=frame_skip, make_new_recording=make_new_recording,
                               experience_folder=experience_folder, exec_mode=exec_mode,
                               craft_equip=craft_equip, monitor_equip=monitor_equip,
                               generate_experience_video=generate_experience_video,
                               debug_video=debug_video
                               ) for i in range(n_cpu)]
    )


# @deprecated
def make_minerl_time_limit(env_id, n_cpu, seq_len, transforms, input_space, use_reprio=False, time_limit=18000):
    """Deprecated: build ``n_cpu`` recorded, time-limited envs in a ThreadedVecEnv.

    Env ``i`` is created with lock/thread id ``i`` and a step limit of ``time_limit * (i + 1)``.
    """
    return ThreadedVecEnv(
        # `i=i` binds each factory to its own index (otherwise all factories would use the last i)
        [lambda i=i: TimeLimit(
            Monitor(
                AgentStateWrapper(
                    RecordingWrapper(
                        reprioritze_env(
                            gym_sync_create(env_id, i),
                            use_reprio
                        ),
                        env_id),
                    seq_len,
                    transforms,
                    input_space),
                None,
                allow_early_resets=True
            ),
            max_episode_steps=time_limit * (i + 1)) for i in range(n_cpu)
         ]
    )


class QueuePipe:
    """One endpoint of an in-process, pipe-like channel built from two queues.

    ``send`` puts onto ``send_q``; ``recv`` blocks on ``req_q``. It mirrors the subset of the
    ``multiprocessing`` connection interface (``send``/``recv``/``close``) that the vec-env worker uses.
    """

    def __init__(self, send_q, req_q):
        self._send_q, self._req_q = send_q, req_q

    def recv(self):
        """Block until a message is available on the receive queue and return it."""
        message = self._req_q.get()
        return message

    def send(self, value):
        """Put ``value`` onto the send queue (blocks while a bounded queue is full)."""
        return self._send_q.put(value)

    def close(self):
        """No-op: queues need no teardown; kept for interface compatibility."""
        return None


class ThreadedVecEnv(SubprocVecEnv):
    """SubprocVecEnv variant whose workers are daemon threads in this process instead of subprocesses.

    Every env gets a pair of QueuePipe endpoints built from two size-1 queues: the main thread talks through
    ``self.remotes`` and the worker thread through ``self.work_remotes``. Workers run ``_worker`` from
    ``train.baselines.subproc_vec_env``.
    """

    def __init__(self, env_fns):
        self.waiting = False
        self.closed = False

        endpoints = []
        for _ in range(len(env_fns)):
            main_to_worker = queue.Queue(1)
            worker_to_main = queue.Queue(1)
            endpoints.append((QueuePipe(main_to_worker, worker_to_main),
                              QueuePipe(worker_to_main, main_to_worker)))
        self.remotes, self.work_remotes = zip(*endpoints)

        self.processes = []
        for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
            # NOTE(review): presumably mimics the CloudpickleWrapper `.var` attribute that `_worker` reads — confirm
            env_fn.var = env_fn
            # daemon=True: if the main process crashes, we should not cause things to hang
            worker = threading.Thread(target=_worker, args=(work_remote, remote, env_fn), daemon=True)
            worker.start()
            self.processes.append(worker)

        # query the spaces from the first worker; all envs are assumed to share them
        first_remote = self.remotes[0]
        first_remote.send(('get_spaces', None))
        observation_space, action_space = first_remote.recv()
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)


def gym_sync_create(env_string, thread_id):
    """Create ``gym.make(env_string)`` while holding an exclusive per-``thread_id`` file lock.

    The lock file lives in ``<tempdir>/<user>/minecraft-<thread_id>.lock``, so env creation with the same
    thread id is serialized across threads and processes of the same user.

    :return: the newly created gym environment.
    """
    lock_dir = os.path.join(tempfile.gettempdir(), getpass.getuser())
    # exist_ok: several worker threads may reach this line at once; a check-then-create would race
    os.makedirs(lock_dir, exist_ok=True)
    lock_path = os.path.join(lock_dir, "minecraft-{}.lock".format(thread_id))
    with open(lock_path, "wb") as lock_file:
        fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
        try:
            return gym.make(env_string)
        finally:
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)


def gym_remote_sync_create(env_string, host, port):
    """Create ``env_string`` on the remote env server listening at ``host:port``."""
    from train.envs.env_client import RemoteGym
    remote_gym = RemoteGym(host, port)
    return remote_gym.make(env_string)


def reprioritze_env(env, use_reprio):
    """Optionally lower the scheduling priority of the env's Minecraft process, then return ``env``.

    When ``use_reprio`` is truthy, the process behind ``env.instance.minecraft_process`` gets niceness 19, and
    every thread of each of its (recursive) child processes is reniced to 19 via the ``renice`` command.
    """
    if not use_reprio:
        return env

    import psutil
    minecraft = psutil.Process(env.instance.minecraft_process.pid)
    minecraft.nice(19)
    for child_process in minecraft.children(recursive=True):
        for thread in child_process.threads():
            subprocess.call(["renice", "19", str(thread.id)])
    return env


class SafeActionsWrapper(gym.Wrapper):
    """Strip actions that MineRLTreechop-v0 does not support before stepping the wrapped env.

    For any other ``env_name`` the action is forwarded unchanged. Keys are removed from the caller's ``action``
    dict in place.
    """

    def __init__(self, env, env_name):
        super(SafeActionsWrapper, self).__init__(env)
        self.env_name = env_name

    def step(self, action):
        if self.env_name == "MineRLTreechop-v0":
            for key in ('place', 'equip', 'craft', 'nearbyCraft', 'nearbySmelt'):
                action.pop(key, None)
        # Pass the inner result through untouched. build_env wraps this wrapper in SaveStepWrapper, which retries
        # when env.step returns the wrong number of values; unpacking into four names here would raise first
        # and defeat that retry.
        return super().step(action)


class MetaControllerWrapper(gym.Wrapper):
    """Copy each transition into ``info['meta_controller']`` for consumers further up the stack."""

    def __init__(self, env):
        super(MetaControllerWrapper, self).__init__(env)

    def step(self, action):
        state, reward, done, info = super().step(action)
        info['meta_controller'] = dict(state=state, action=action, reward=reward, done=done)
        return state, reward, done, info


class AgentStateWrapper(gym.Wrapper):
    """Turn raw transitions into the sequence-based observation dict expected by the agent.

    Transitions are accumulated in an ``AgentState``; observations are built by the active input space
    (the PackageManager dataset's ``INPUT_SPACE`` when a package is enabled, else ``input_space``).
    """

    def __init__(self, env, seq_len, transforms, input_space):
        super().__init__(env)
        self.observation_space = None
        self.seq_len = seq_len
        self.input_space = input_space
        self.sequential_state = AgentState(sequence_length=seq_len, data_transform=transforms)

    def _prepare_observation(self):
        (pov, binary_actions, camera_actions, enum_actions,
         inventory, equipped_items, rewards) = self.sequential_state.prepare_observation()
        package_manager = PackageManager.get_instance()
        if package_manager.enabled():
            active_space = package_manager.dataset.INPUT_SPACE
        else:
            active_space = self.input_space
        return active_space.prepare(pov, binary_actions, camera_actions, enum_actions, inventory,
                                    equipped_items, rewards, np.zeros_like(rewards))

    def step(self, action):
        state, reward, done, info = super().step(action)
        self.sequential_state.update(state=state, action=action, reward=reward, done=done, info=info)
        return self._prepare_observation(), reward, done, info

    def reset(self, **kwargs):
        super().reset()
        # fill the sequence buffer by taking noop+jump steps before handing the first observation out
        warmup_steps = self.sequential_state.required_sequence_length
        for _ in range(warmup_steps):
            warmup_action = self.env.action_space.noop()
            warmup_action["jump"] = 1
            state, reward, done, info = self.env.step(warmup_action)
            self.sequential_state.update(state, warmup_action, reward, done, info)
        return self._prepare_observation()


class ExposeRecordedExperienceWrapper(gym.Wrapper):
    """Expose the experience held by the nested ExperienceRecordingWrapper at the top of the stack."""

    def __init__(self, env):
        super(ExposeRecordedExperienceWrapper, self).__init__(env)
        self.env = env
        self.experience = self.recursive_experience_find(env).experience

    def recursive_experience_find(self, env):
        """Follow ``.env`` links down from ``env`` until an ExperienceRecordingWrapper is found and return it."""
        node = env
        while node.__class__.__name__ != 'ExperienceRecordingWrapper':
            node = node.env
        return node

    def get_experience(self):
        recorder = self.recursive_experience_find(self.env)
        return recorder.experience

    def set_experience(self, new_experience):
        recorder = self.recursive_experience_find(self.env)
        recorder.experience = new_experience


class ExposeExperiencePlaybackWrapper(gym.Wrapper):
    """Expose the recording replayed by the nested ExperiencePlaybackWrapper at the top of the stack."""

    def __init__(self, env):
        super(ExposeExperiencePlaybackWrapper, self).__init__(env)
        self.env = env
        self.experience = self.recursive_experience_find(env).experience_recording

    def recursive_experience_find(self, env):
        """Follow ``.env`` links down from ``env`` until an ExperiencePlaybackWrapper is found and return it."""
        node = env
        while node.__class__.__name__ != 'ExperiencePlaybackWrapper':
            node = node.env
        return node

    def get_experience(self):
        player = self.recursive_experience_find(self.env)
        return player.experience_recording

    def set_experience(self, new_experience):
        # the replay seed comes from the recording's meta_info, so update both together
        player = self.recursive_experience_find(self.env)
        player.env_seed = new_experience.meta_info
        player.experience_recording = new_experience


class ExperiencePlaybackWrapper(gym.Wrapper):
    """
    The ExperiencePlaybackWrapper is used to reset the environment to a previously recorded state.
    When giving the additional arguments seed, experience_recording and checkpoint, the agent will execute the
    actions that have been recorded in the experience_recording
    """

    def __init__(self, env, experience_recording=None, checkpoint=None, max_replay_attempts=None):
        """
        :param experience_recording: the recorded experience to replay; required. Its ``meta_info`` is used as seed.
        :param checkpoint: the checkpoint string. Can be 'latest' to use the latest checkpoint in the recording.
        :param max_replay_attempts: number of reset+replay attempts before giving up with RuntimeError;
            ``None`` (default) retries indefinitely.
        """
        super().__init__(env)
        self.experience_recording = experience_recording
        assert self.experience_recording is not None
        self.env_seed = self.experience_recording.meta_info
        if checkpoint == 'latest':
            self.checkpoint = next(reversed(self.experience_recording.checkpoints))
        else:
            self.checkpoint = checkpoint
        self.max_replay_attempts = max_replay_attempts

    def reset(self, **kwargs):
        """Reset with the recorded seed and replay the recording up to ``self.checkpoint``.

        Failed replays (e.g. an inventory mismatch raised by ``replay_on_env``) are retried with a fresh reset.
        Retries use a loop rather than recursion, so repeated failures cannot exhaust the stack.

        :return: the observation after replaying up to the checkpoint.
        :raises RuntimeError: if ``max_replay_attempts`` is set and every attempt failed.
        """
        attempt = 0
        while True:
            attempt += 1
            if self.env_seed:
                self.env.seed(self.env_seed)
            super().reset()

            # we have to catch exceptions here because replay_on_env throws an InventoryMatch Exception
            # if the inventories of live and recorded trajectories don't match
            try:
                self.env, obs = self.experience_recording.replay_on_env(
                    env=self.env,
                    checkpoint=self.checkpoint,
                    noop_action=self.env.action_space.noop()
                )
                return obs
            except Exception as e:
                print(e, 'Something went wrong when replaying this trajectory. Trying new reset.')
                if self.max_replay_attempts is not None and attempt >= self.max_replay_attempts:
                    raise RuntimeError(
                        'Replaying the experience failed {} times; giving up.'.format(attempt)) from e


class InitialSeedWrapper(gym.Wrapper):
    """
    Seed NumPy once with ``initial_seed``, then give the env a freshly drawn random seed on every reset.

    The drawn seed is kept in ``env_seed`` so that outer wrappers (e.g. ExperienceRecordingWrapper) can read it.
    NOTE(review): ``np.random.seed`` seeds NumPy's *global* RNG; when several envs are built in threads, they all
    reseed and draw from that shared generator, so per-env seeds depend on reset interleaving.
    """

    def __init__(self, env, initial_seed):
        super().__init__(env)
        self.env = env
        self.env_seed = None
        self.initial_seed = initial_seed
        np.random.seed(int(initial_seed))

    def reset(self):
        episode_seed = np.random.randint(1, 8000000)
        logging.info("SeedWrapper: Seed of this env is {} based on initial seed {}".format(episode_seed,
                                                                                           self.initial_seed))
        self.env.seed(episode_seed)
        self.env_seed = episode_seed
        return self.env.reset()


class SeedWrapper(gym.Wrapper):
    """Reset the wrapped env with the same fixed seed on every episode.

    The seed is also published as ``env_seed``, which is the attribute ExperienceRecordingWrapper reads from
    the env it wraps. Without it, the two wrappers could not be combined.
    """

    def __init__(self, env, seed):
        super().__init__(env)
        self.env = env
        # NOTE: this instance attribute shadows gym.Wrapper.seed(); it is kept only for backward compatibility
        # with callers reading `wrapper.seed`. Prefer `env_seed`.
        self.seed = seed
        self.env_seed = seed

    def reset(self, **kwargs):
        """Seed the inner env with ``env_seed`` and reset it, forwarding any reset kwargs."""
        self.env.seed(self.env_seed)
        return self.env.reset(**kwargs)


class ExperienceRecordingWrapper(gym.Wrapper):
    """
    Record (state, action, reward, done, info) transitions into an ``Experience``.

    The wrapped env must expose ``env_seed`` (read at construction and after every reset), e.g.
    InitialSeedWrapper or ExperiencePlaybackWrapper. The seed is stored as the experience's ``meta_info``.

    If ``record_obs`` is True the full observation is recorded. Otherwise only ``inventory`` and
    ``equipped_items`` are kept, which is enough to check later whether a replay reproduced the recording
    (MineRL's random drop directions mean items are not always picked up the same way).

    With ``extend_experience`` set, each reset continues the inner env's ``experience_recording`` instead of
    starting a new ``Experience()``.
    """

    def __init__(self, env, save_on_disk=False, save_path=None, env_is_obtain_diamond_sparse=False, record_obs=False,
                 extend_experience=False, experience_recording=None):
        super(ExperienceRecordingWrapper, self).__init__(env)

        if save_on_disk:
            assert save_path is not None, "When setting save_on_disk to True, you have to provide a save_path"

        self.env = env
        self.seed = env.env_seed
        self.save_on_disk = save_on_disk
        self.record_pov = record_obs
        self.current_obs = None
        self.extend_experience = extend_experience

        self.experience = None
        if experience_recording:
            assert experience_recording is not None, "If extending an existing experience, you must provide such an " \
                                                     "existing experience... "
            self.experience = experience_recording
            self.experience_recording_after_init = experience_recording

        self.counter = 0
        self.env_is_obtain_diamond_sparse = env_is_obtain_diamond_sparse

    def step(self, action):
        next_obs, reward, done, info = self.env.step(action)
        self.current_obs = next_obs

        if self.record_pov:
            recorded_state = self.current_obs
        else:
            # keep only the inventory-related entries ('pov' is dropped)
            recorded_state = {
                'inventory': self.current_obs['inventory'],
                'equipped_items': self.current_obs['equipped_items'],
            }
        self.experience.append(state=recorded_state, action=action, reward=reward, done=done, info=info)

        return next_obs, reward, done, info

    def reset(self, **kwargs):
        # reset first so the inner env picks its new seed / replayed experience
        self.current_obs = self.env.reset()
        self.experience = self.env.experience_recording if self.extend_experience else Experience()
        # tag the experience with the seed the env was just reset with
        self.seed = self.env.env_seed
        self.experience.meta_info = self.seed
        return self.current_obs


class ExperienceSamplingWrapper(gym.Wrapper):
    """
    The ExperienceSamplingWrapper is used to reset the environment to a previously recorded state.

    On each reset a recorded experience is sampled at random from ``experience_folder`` and replayed on the
    wrapped env up to its latest checkpoint. A recording that fails to replay 3 times in a row is flagged as
    corrupted on disk and is not sampled again by this wrapper.
    """

    def __init__(self, env, experience_folder):
        super().__init__(env)
        assert experience_folder is not None
        self.env = env
        self.experience_folder = experience_folder
        self.experience_filenames = os.listdir(experience_folder)
        # remove the subfolder 'corrupted' from the list of experience filenames
        if 'corrupted' in self.experience_filenames:
            self.experience_filenames.remove('corrupted')

        # this is set to true so that recording wrapper works correctly
        self.experience_recording = True

    def _discard(self, experience_file):
        """Remove ``experience_file`` from the sampling candidates."""
        if experience_file in self.experience_filenames:
            self.experience_filenames.remove(experience_file)

    def _sample_uncorrupted_experience(self):
        """Load random recordings until one not flagged as corrupted is found.

        Sets ``self.experience_recording`` to the loaded recording.

        :return: ``(filename, full_path)`` of the selected recording.
        :raises RuntimeError: if no candidate recordings are left.
        """
        while self.experience_filenames:
            experience_file = np.random.choice(self.experience_filenames)
            experience_path = os.path.join(self.experience_folder, experience_file)
            self.experience_recording = Experience.load(experience_path)
            if not hasattr(self.experience_recording, 'corrupted'):
                # older recordings lack the flag: add it and save so the attribute persists on disk
                self.experience_recording.corrupted = False
                self.experience_recording.save(experience_path)
            if not self.experience_recording.corrupted:
                return experience_file, experience_path
            self._discard(experience_file)
        raise RuntimeError('No uncorrupted experience left to sample in {}'.format(self.experience_folder))

    def reset(self, **kwargs):
        """Reset the env and fast-forward it along a randomly sampled, replayable recorded experience.

        :return: the observation after replaying the sampled experience up to its latest checkpoint.
        :raises RuntimeError: if every recording in the folder is corrupted.
        """
        max_trials = 3
        while True:
            experience_file, experience_path = self._sample_uncorrupted_experience()
            env_seed = self.experience_recording.meta_info
            checkpoint = next(reversed(self.experience_recording.checkpoints))

            for trial in range(max_trials):
                # re-apply the recorded seed before every reset so each attempt starts in the recorded world
                if env_seed:
                    self.env.seed(env_seed)
                super().reset()

                # the sampled experience may not replay correctly: replay_on_env throws an InventoryMatch
                # Exception if the inventories of live and recorded trajectories don't match
                print('Trying experience replay on env. Trial {}/{}'.format(trial + 1, max_trials))
                try:
                    self.env, obs = self.experience_recording.replay_on_env(
                        env=self.env,
                        checkpoint=checkpoint,
                        noop_action=self.env.action_space.noop()
                    )
                except Exception as e:
                    print(e, 'Something went wrong when replaying this trajectory.')
                    continue
                print('success replaying the trajectory!')
                return obs

            # if we reach this point, the experience is corrupted
            print('repeating this trajectory has not worked. setting corrupted state to true for', experience_file)
            self.experience_recording.corrupted = True
            self.experience_recording.save(experience_path)
            self._discard(experience_file)


class RecordingWrapper(gym.Wrapper):
    """Accumulate the current episode in an ``AgentEpisode`` and expose its sequence via ``info['episode']``."""

    def __init__(self, env, env_id):
        super(RecordingWrapper, self).__init__(env)
        self.env_id = env_id
        self.episode = None

    def step(self, action):
        transition = self.env.step(action)
        state, reward, done, info = transition
        self.episode.append(state=state, action=action, reward=reward, done=done, info=info)
        info['episode'] = self.episode.sequence
        return state, reward, done, info

    def reset(self, **kwargs):
        # TODO: Make preprocessor generic
        is_treechop = self.env_id == 'MineRLTreechop-v0'
        preprocessor_cls = MineRLTreechopPklPreprocessor if is_treechop else MineRLObtainDiamondPklPreprocessor
        self.episode = AgentEpisode(preprocessor_cls())
        return self.env.reset()


class VecStackedWrapper(gym.Wrapper):
    """Replace the observation space with an empty ``gym.spaces.Dict``, keeping the original in ``_ori_obs_space``."""

    def __init__(self, env):
        super(VecStackedWrapper, self).__init__(env)
        self._ori_obs_space = self.env.observation_space
        self.observation_space = gym.spaces.Dict()

    def reset(self, **kwargs):
        return self.env.reset()


class StackWrapper:
    """Keep one ``AgentState`` per sub-env of a vectorized env and pre-fill them on reset.

    Note: ``(x,) * n`` / ``[x] * n`` repeat a *reference* to one object, so every per-env container here is
    built with a comprehension to get independent instances.
    """

    def __init__(self, env, dataset, n_cpu):
        self.env = env
        self.n_cpu = n_cpu
        self.agent_states = tuple(AgentState(sequence_length=dataset.SEQ_LENGTH,
                                             data_transform=dataset.DATA_TRANSFORM)
                                  for _ in range(n_cpu))

    def _to_dict_list(self, state):
        """Convert a dict of per-env sequences into a list of per-env dicts.

        ``{'a': [a0, a1], 'b': [b0, b1]}`` -> ``[{'a': a0, 'b': b0}, {'a': a1, 'b': b1}]``. The length is taken
        from the first key's value; an empty dict yields ``[]``.
        """
        if not state:
            return []
        n_items = len(next(iter(state.values())))
        return [{key: values[i] for key, values in state.items()} for i in range(n_items)]

    def reset(self):
        """Fill every agent state with ``required_sequence_length`` noop+jump transitions and return them."""
        for _ in range(self.agent_states[0].required_sequence_length):
            # a separate noop dict per env, so mutating one action never affects another env's action
            actions = tuple(self.env.action_space.noop() for _ in range(self.n_cpu))
            for action in actions:
                action["jump"] = 1
            state, reward, done, info = self.env.step(actions)
            per_env_state = self._to_dict_list(state)
            for i, agent_state in enumerate(self.agent_states):
                agent_state.update(state=per_env_state[i],
                                   action=actions[i],
                                   reward=reward[i],
                                   done=done[i],
                                   info=info[i])
        return self.agent_states

    def step(self, actions):
        return self.env.step(actions)


class FrameSkip(gym.Wrapper):
    """Return every `skip`-th frame and repeat given action during skip.
    Note that this wrapper does not "maximize" over the skipped frames.

    Crafting, equipping, smelting and placing are one-shot actions: they are sent only on the first skipped
    frame, and the remaining frames use a noop. Rewards are summed over the skipped frames, and the loop stops
    early when the episode ends. The effective skip comes from the PackageManager dataset's ``FRAME_SKIP`` when a
    package is enabled, otherwise from ``skip``. It is clamped to at least 1 so a step is always taken.
    """

    def __init__(self, env, skip=1):
        super().__init__(env)

        self._skip = skip
        self.noop = deepcopy(env.action_space.noop())

    def step(self, action):
        one_shot_keys = ("craft", "equip", "nearbyCraft", "nearbySmelt", "place")

        # fix none / 0 issue: callers may encode "no discrete action" as 0 instead of "none"
        for key in one_shot_keys:
            if key in action and action[key] == 0:
                action[key] = "none"

        # check if we perform a crafting, equipping, smelting or placing action
        craft_equip_smelt_place = any(key in action and action[key] != "none" for key in one_shot_keys)

        skip = self._skip if not PackageManager.get_instance().enabled() else PackageManager.get_instance().dataset.FRAME_SKIP
        # at least one step, otherwise obs/done/info would be unbound below
        skip = max(1, skip)

        total_reward = 0.0
        for _ in range(skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward

            # one-shot actions must not be repeated on the skipped frames
            if craft_equip_smelt_place:
                action = self.noop

            if done:
                break

        return obs, total_reward, done, info


class SaveStepWrapper(gym.Wrapper):
    """Sometimes the env.step only returns 3 values causing an exception.
    We don't know why this happens so we include this wrapper here.

    A malformed step result is retried up to ``max_patience + 1`` more times (re-sending the same action). If
    the result is still malformed, a ValueError with a descriptive message is raised.
    """

    def __init__(self, env):
        super().__init__(env)
        self.max_patience = 2
        self.patience = self.max_patience

    def step(self, action):
        step_return = self.env.step(action)
        # per-call retry budget: the previous shared counter was only restored after a successful unpack, so
        # one exhausted call left every later call with no retries at all
        retry_budget = self.max_patience + 1
        retries_left = retry_budget
        while len(step_return) != 4 and retries_left > 0:
            step_return = self.env.step(action)
            retries_left -= 1

        if len(step_return) != 4:
            raise ValueError('env.step returned {} values instead of 4 after {} retries'.format(
                len(step_return), retry_budget))

        obs, reward, done, info = step_return
        self.patience = self.max_patience
        return obs, reward, done, info


class CraftEquipWrapper(gym.Wrapper):
    """
    Equips crafted items

    After a ``craft`` / ``nearbyCraft`` action, the wrapper waits ``n_noops`` noop steps for the item to appear
    and then equips it if it is a valid ``equip`` value. Rewards of all these steps are summed. The sequence
    stops as soon as the episode is done.

    Notes:

    We have performed Wilcoxon statistical test for actions appearing together, under the assumption that they
    follow a Poisson distribution for the design of our action space. For non-imitation learning agents,
    some actions are performed one after another. For example, "nearbyCraft item" and "equip item" significantly
    appear together, and therefore, if we perform action "nearbyCraft wooden pickaxe" we also
    perform "equip wooden pickaxe"
    """
    def __init__(self, env, n_noops=5):
        super().__init__(env)
        self.noop = deepcopy(env.action_space.noop())
        self.n_noops = n_noops

    def step(self, action):
        total_reward = 0.0
        # fix none / 0 issue: callers may encode "no discrete action" as 0 instead of "none"
        for key in ["craft", "equip", "nearbyCraft", "nearbySmelt", "place"]:
            if key in action and action[key] == 0:
                action[key] = "none"

        # item being crafted, if any; nearbyCraft takes precedence over craft when both are set
        crafted_item = None
        for key in ("craft", "nearbyCraft"):
            if key in action and action[key] != "none":
                crafted_item = action[key]

        # send original action to env (crafts the item, if any)
        obs, reward, done, info = self.env.step(action)
        total_reward += reward
        if crafted_item is None or done:
            return obs, total_reward, done, info

        # wait n noops for item to appear in inventory; never step an env whose episode has ended
        for _ in range(self.n_noops):
            obs, reward, done, info = self.env.step(self.noop)
            total_reward += reward
            if done:
                return obs, total_reward, done, info

        # equip the crafted item if the action space has an equip action that accepts it
        if 'equip' in self.noop and crafted_item in self.action_space['equip'].values:
            new_action = self.noop.copy()
            new_action['equip'] = crafted_item
            obs, reward, done, info = self.env.step(new_action)
            total_reward += reward

        return obs, total_reward, done, info


class MonitorEquipWrapper(gym.Wrapper):
    """
    Monitors the inventory and equips the highest ranked item. Ignores agent equip statements.

    Items are ranked by durability (``item_rank``). After the agent's action is executed with its equip part
    cleared, the best owned item is equipped if the main hand's ``maxDamage`` is lower than that item's rank.
    No extra equip step is taken once the episode is done.
    """
    def __init__(self, env):
        super().__init__(env)
        self.noop = deepcopy(env.action_space.noop())
        # list according to minecraft durability specification
        # https://minecraft.gamepedia.com/Pickaxe
        self.item_rank = {
            'iron_pickaxe': 250,
            'iron_axe': 250,
            'stone_pickaxe': 131,
            'stone_axe': 131,
            'wooden_pickaxe': 59,
            'wooden_axe': 59
        }

    def step(self, action):
        if 'equip' not in action:
            return self.env.step(action)

        total_reward = 0.0
        # override the agent's equip choice: this wrapper decides what gets equipped
        action['equip'] = "none"
        obs, reward, done, info = self.env.step(action)
        total_reward += reward
        if done:
            return obs, total_reward, done, info

        # find best item in the inventory; max() keeps the first of equally ranked items (dict order)
        inventory = obs['inventory']
        owned_items = [item for item in self.item_rank if item in inventory and inventory[item] > 0]
        best_item = max(owned_items, key=self.item_rank.get, default=None)

        # set only if a good item exists according to the item_rank and the mainhand has a worse item equipped
        if best_item is not None and obs['equipped_items']['mainhand']['maxDamage'] < self.item_rank[best_item]:
            equip_action = deepcopy(self.noop)
            equip_action['equip'] = best_item
            obs, reward, done, info = self.env.step(equip_action)
            total_reward += reward

        return obs, total_reward, done, info
