import gym
import numpy as np
import os

from .image_task import ImageTask


class MiniHackObsWrapper(gym.ObservationWrapper):
    """Reduce MiniHack's dict observation to its ``"pixel_crop"`` image.

    The wrapped environment must expose ``"pixel_crop"`` in its observation
    dict; otherwise :meth:`observation` raises ``KeyError``.
    """

    def __init__(self, env):
        super().__init__(env)
        # Advertised space: 84x84 RGB image, uint8 values in [0, 255].
        # NOTE(review): this shape is declared here, not derived from the
        # wrapped env's own "pixel_crop" space — confirm the raw crop matches,
        # or that it is resized downstream.
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(84, 84, 3),
            dtype=np.uint8,
        )

    def observation(self, obs):
        """Return only the cropped pixel view from the observation dict."""
        return obs["pixel_crop"]


class MiniHackMakeVecSafeWrapper(gym.Wrapper):
    """Thin pass-through wrapper around a MiniHack environment.

    ``step``, ``reset`` and ``close`` delegate directly to the wrapped env.
    The working directory at construction time is recorded in ``basedir``.
    """

    def __init__(self, env):
        super().__init__(env)
        # Working directory at construction time. It is not read anywhere in
        # this class.
        # NOTE(review): presumably used by code elsewhere (e.g. restoring cwd
        # in vectorized workers) — confirm before removing.
        self.basedir = os.getcwd()

    def step(self, action: int):
        """Forward ``action`` to the wrapped env and return its result unchanged."""
        return self.env.step(action)

    def reset(self, **kwargs):
        """Reset the wrapped env, forwarding any keyword arguments.

        Forwarding ``**kwargs`` matters when an outer wrapper or vector env
        calls ``reset(seed=..., options=...)``. Without it, this wrapper
        raised ``TypeError`` instead of passing the arguments through.
        Calls with no arguments behave exactly as before.
        """
        return self.env.reset(**kwargs)

    def close(self):
        """Close the wrapped env."""
        self.env.close()



def make_minihack(
    env_name,
    observation_keys=None,
    reward_win=1,
    reward_lose=0,
    penalty_time=0.0,
    penalty_step=-0.001,
    penalty_mode="constant",
    character="mon-hum-neu-mal",
    savedir=None,
    **kwargs,
):
    """Create the MiniHack env ``f"MiniHack-{env_name}"`` and wrap it.

    Args:
        env_name: Suffix of the registered gym id (``"MiniHack-" + env_name``).
        observation_keys: Observation keys requested from MiniHack. Defaults to
            a fresh ``["pixel_crop"]`` list on every call. Any custom value
            must include ``"pixel_crop"``, because ``MiniHackObsWrapper``
            reads that key.
        reward_win, reward_lose, penalty_time, penalty_step, penalty_mode,
        character, savedir: Forwarded unchanged to ``gym.make``.
        **kwargs: Any further keyword arguments, also forwarded to ``gym.make``.

    Returns:
        The env wrapped first in ``MiniHackMakeVecSafeWrapper`` and then in
        ``MiniHackObsWrapper``. Its observations are the ``"pixel_crop"`` array.
    """
    # Imported only for its side effect: it registers the "MiniHack-*" ids
    # that gym.make looks up below.
    import minihack  # noqa: F401

    # A literal list default would be one object shared by every call and
    # handed to a third-party constructor. Build a fresh list per call instead.
    if observation_keys is None:
        observation_keys = ["pixel_crop"]

    env = gym.make(
        f"MiniHack-{env_name}",
        observation_keys=observation_keys,
        reward_win=reward_win,
        reward_lose=reward_lose,
        penalty_time=penalty_time,
        penalty_step=penalty_step,
        penalty_mode=penalty_mode,
        character=character,
        savedir=savedir,
        **kwargs,
    )  # each env specifies its own self._max_episode_steps
    print("Environment created:", env)
    env = MiniHackMakeVecSafeWrapper(env)
    env = MiniHackObsWrapper(env)
    return env


def get_single_minihack_task(task_id, action_space_id, env_name, num_timesteps, eval_mode=False, **kwargs):
    """Build an ``ImageTask`` backed by a lazily created MiniHack environment.

    ``env_name`` and any extra ``kwargs`` are forwarded to ``make_minihack``
    each time the task calls its ``env_spec`` factory. Observations are
    declared as 84x84 colour images with no frame stacking.
    """

    def build_env():
        # Passed as a factory, so the env is only built when the task invokes it.
        return make_minihack(env_name, **kwargs)

    return ImageTask(
        task_id=task_id,
        action_space_id=action_space_id,
        env_spec=build_env,
        num_timesteps=num_timesteps,
        eval_mode=eval_mode,
        time_batch_size=1,  # no framestack
        grayscale=False,
        image_size=[84, 84],
    )