import numpy as np

from gym.utils import seeding
from gym.envs import registration
from minihack import MiniHackNavigation, LevelGenerator, RewardManager
from minihack.tiles.rendering import get_des_file_rendering

from envs.minihack import minihack_grid 
from nle import nethack

from util import cprint

from nle.env.base import FULL_ACTIONS

# Action set exposed to the agent: the EAT command first, followed by every
# compass-direction move in the order nethack.CompassDirection defines them.
_ACTIONS = (nethack.Command.EAT, *nethack.CompassDirection)

# Reverse lookup: action -> its position in _ACTIONS.
_ACTION2INDEX = dict(zip(_ACTIONS, range(len(_ACTIONS))))

# Flag handed to cprint() as its first argument by the debug traces below.
DEBUG = False


class MiniHackRoomBinaryChoice(MiniHackNavigation):
    """Square MiniHack room in which the agent chooses between two fruits.

    Every generated level contains the agent start ('<'), one apple ('R',
    referred to as the "red" goal) and one banana ('B', the "blue" goal),
    plus optional wall, lava and monster tiles. With probability ``p`` the
    apple is the rewarded goal, otherwise the banana is. The non-selected
    fruit gives reward 0 unless ``unexclusive_goals`` is set, in which case
    both fruits get a sampled reward. Both eat events are registered on the
    reward manager with ``terminal_sufficient=True``.

    Args:
        p: Probability that the apple (red) is the rewarded goal.
        rewards: ``[apple, banana]`` centres of the reward distributions.
        reward_spreads: ``[apple, banana]`` spreads; the half-width for
            ``'uniform'`` and the scale passed to ``normal`` for
            ``'normal'``.
        reward_dist: ``'uniform'`` or ``'normal'``.
        size: Width and height of the room.
        n_wall: Number of wall ('-') placement attempts per level.
        n_lava: Number of lava ('L') placement attempts per level.
        n_monster: Number of monster ('m') placement attempts per level.
        skip_tile_p: Threshold for skipping a placement attempt; a tile is
            only placed when a fresh ``rand()`` draw exceeds this value.
        max_episode_steps: Forwarded to the parent environment.
        lit: Forwarded to the ``LevelGenerator`` that builds the placeholder
            level handed to the parent constructor.
        seed: Seed for the level RNG and the grid.
        obl_correction: If True, the rewarded goal is drawn from the
            unseeded RNG instead of the seeded one.
        unexclusive_goals: If True, both fruits receive a sampled reward.
        goal_hint_p: Probability of placing a hint gem ('r' ruby when the
            apple is rewarded, 'b' sapphire otherwise).
        fixed_environment: Accepted for interface compatibility; not used
            by this class.
        **kwargs: Forwarded to ``MiniHackNavigation``. ``autopickup``
            defaults to False, ``actions`` to ``_ACTIONS`` and
            ``allow_all_yn_questions`` to True when not supplied.
    """

    def __init__(
        self,
        *args,
        p=0.7,
        rewards=(3, 10),
        reward_spreads=(2, 0),
        reward_dist='uniform',
        size=15,
        n_wall=0,
        n_lava=0,
        n_monster=0,
        skip_tile_p=0.0,
        max_episode_steps=250,
        lit=True,
        seed=None,
        obl_correction=False,
        unexclusive_goals=False,
        goal_hint_p=0.0,
        fixed_environment=False,
        **kwargs
    ):
        kwargs["max_episode_steps"] = max_episode_steps
        kwargs.setdefault("autopickup", False)
        kwargs.setdefault("actions", _ACTIONS)
        kwargs.setdefault("allow_all_yn_questions", True)

        self.size = size
        self.n_wall = n_wall
        self.n_lava = n_lava
        self.n_monster = n_monster
        self.skip_tile_p = skip_tile_p
        self.obl_correction = obl_correction
        self.unexclusive_goals = unexclusive_goals
        self.p = p
        self.goal_hint_p = goal_hint_p
        # Copy so that instances never share (or mutate) the caller's or the
        # default sequences.
        self.rewards = list(rewards)
        self.reward_spreads = list(reward_spreads)
        self.reward_dist = reward_dist
        self.is_red_goal = False

        # Message observed on the previous step; step() uses it to work out
        # which fruit was eaten.
        self.last_message = ''

        # The grid must exist before seed() is called, since seed() reseeds
        # self.grid.
        self._init_empty_grid(seed=seed)

        # Placeholder level for the parent constructor; the real level is
        # installed from self.grid.des on reset().
        self.lvl_gen = LevelGenerator(w=size, h=size, lit=lit)
        self.reward_manager = RewardManager()

        super().__init__(*args,
            des_file=self.lvl_gen.get_des(), **kwargs)

        self.penalty_step = 0.0

        self.seed(seed)  # sets np_random seed

        self.unseeded_np_random, _ = seeding.np_random()

        # Start from a fresh grid that carries the requested seed. Creating
        # it without the seed here would discard the seeding done above and
        # make tile placement ignore ``seed``.
        self._init_empty_grid(seed=seed)

        self._regenerate_level()

    def _init_empty_grid(self, seed=None):
        """Create a fresh ``self.grid`` and register the custom objects.

        Args:
            seed: Seed forwarded to ``minihack_grid.Grid``.
        """
        self.grid = minihack_grid.Grid(
            width=self.size,
            height=self.size,
            seed=seed)

        self.grid.add_custom_objects({
                # Goals
                'R': {'name':'apple', 'symbol':'%'},
                'B': {'name':'banana', 'symbol':'%'},

                # Goal indicators
                'r': {'name':'ruby', 'symbol':'*'},
                'b': {'name':'sapphire', 'symbol':'*'},
            })

    def _set_char(self, char, loc):
        """Place ``char`` at ``loc`` unless the attempt is randomly skipped.

        One draw is taken from the seeded RNG per call; the tile is placed
        only when that draw exceeds ``self.skip_tile_p``.
        """
        if self.np_random.rand() > self.skip_tile_p:
            self.grid.set_char(char, loc)

    def _regenerate_level(self):
        """Build a new level on ``self.grid`` and resample the goal rewards.

        Steps, in order:
          1. Clear the grid and attempt ``n_wall`` wall, ``n_lava`` lava and
             ``n_monster`` monster placements at random locations.
          2. Place the agent start, the apple and the banana on free cells.
          3. Pick the rewarded fruit, sample the rewards and register both
             eat events on a fresh ``RewardManager``.
          4. With probability ``goal_hint_p`` place a hint gem.
          5. Store the grid metrics (passability, shortest path lengths,
             clutter counts).
        """
        # === Generate topology ===
        self.grid.clear()

        clutter_requests = (
            ('-', self.n_wall),
            ('L', self.n_lava),
            ('m', self.n_monster),
        )
        for char, count in clutter_requests:
            for _ in range(count):
                self._set_char(char, loc='random')

        # Add agent
        self.grid.set_char('<', loc='random_free')

        # Add apple and banana
        self.grid.set_char('R', loc='random_free')
        self.grid.set_char('B', loc='random_free')

        # === Update rewards ===
        # As shorthand, apple is "red", banana is "blue"
        self.reward_manager = RewardManager()

        self.red_reward, self.blue_reward = 0, 0

        # With obl_correction the goal identity comes from the unseeded RNG
        # (maintains a uniform prior over levels); otherwise it is fixed by
        # the level seed.
        goal_rng = (
            self.unseeded_np_random if self.obl_correction else self.np_random
        )
        self.is_red_goal = goal_rng.rand() < self.p

        # The rewarded fruit is always sampled first so the draw order is
        # stable for a given goal.
        if self.is_red_goal:
            self.red_reward = self._sample_reward_for_goal(0)
            if self.unexclusive_goals:
                self.blue_reward = self._sample_reward_for_goal(1)
            cprint(DEBUG, f'Eat the apple, r={self.red_reward}')
        else:
            self.blue_reward = self._sample_reward_for_goal(1)
            if self.unexclusive_goals:
                self.red_reward = self._sample_reward_for_goal(0)
            cprint(DEBUG, f'Eat the banana for, r={self.blue_reward}')

        self.reward_manager.add_eat_event(
            'apple', reward=self.red_reward, terminal_sufficient=True)
        self.reward_manager.add_eat_event(
            'banana', reward=self.blue_reward, terminal_sufficient=True)

        # === Add goal indicators ===
        # Hint gem: ruby ('r') when the apple is rewarded, sapphire ('b')
        # when the banana is.
        if self.np_random.rand() < self.goal_hint_p:
            if self.is_red_goal:
                fruit_char, hint_char = 'R', 'r'
            else:
                fruit_char, hint_char = 'B', 'b'

            # NOTE(review): the polarity of these masks is defined by
            # minihack_grid.Grid and is not visible here -- confirm that the
            # combined mask restricts the hint to the agent's neighbourhood
            # and excludes the fruit cell.
            neighbor_mask = ~self.grid.mask_neighbors('<', first=True)
            fruit_mask = self.grid.mask(fruit_char, first=True)
            loc_mask = neighbor_mask & fruit_mask

            self.grid.set_char(hint_char, mask=loc_mask)

        # === Compute shortest path metrics ===
        grid_info = self.grid.get_metrics(
            goal_chars=['R', 'B'],
            clutter_chars=['-', 'm'],
            aliases={
                'R': 'red',
                'B': 'blue',
                '-': 'wall',
                'm': 'monster'
            })
        self.passable = grid_info['passable']
        self.shortest_path_length = grid_info['shortest_path_lengths']
        self.n_clutter_placed = grid_info['clutter_counts']

        cprint(DEBUG, grid_info)

    def _sample_reward_for_goal(self, idx):
        """Sample the reward for goal ``idx`` (0 = apple/red, 1 = banana/blue).

        Rewards are always drawn from the unseeded RNG:
          * ``'normal'``: ``normal(mean, spread)``.
          * ``'uniform'``: uniform on ``[mean - spread, mean + spread)``.

        Raises:
            ValueError: If ``self.reward_dist`` is not a supported name.
        """
        mean = self.rewards[idx]
        spread = self.reward_spreads[idx]
        if self.reward_dist == 'normal':
            reward = self.unseeded_np_random.normal(mean, spread)
        elif self.reward_dist == 'uniform':
            reward = self.unseeded_np_random.rand()*2*spread + mean - spread
        else:
            raise ValueError(f'Unsupported reward dist {self.reward_dist}.')

        return reward

    def step(self, action):
        """Step the parent environment and label which fruit ended the episode.

        When the episode is done, ``info['target']`` is set to ``'red'`` or
        ``'blue'`` if the final action was the one bound to "y" and the
        previous step's message mentioned the apple or the banana
        respectively; otherwise it is set to ``None``. ``info['target']`` is
        not written on non-terminal steps.
        """
        obs, r, done, info = super().step(action)

        cur_message = obs['message'].tobytes().decode('utf-8').lower()

        if done:
            # The NW move doubles as the "y" answer to the eat prompt.
            # NOTE(review): the index comes from the module-level _ACTIONS;
            # it presumably does not match if a custom ``actions`` kwarg was
            # supplied -- confirm against callers.
            action_is_y = action == _ACTION2INDEX[nethack.CompassDirection.NW]
            if action_is_y and 'apple' in self.last_message:
                info['target'] = 'red'
            elif action_is_y and 'banana' in self.last_message:
                info['target'] = 'blue'
            else:
                info['target'] = None

        self.last_message = cur_message

        return obs, r, done, info

    def reset(self):
        """Generate a new level, install it and reset the parent environment."""
        self._regenerate_level()
        # Drop the previous episode's message so it cannot influence the
        # target labelling of the new episode.
        self.last_message = ''
        super().update(des_file=self.grid.des)
        return super().reset()

    def reset_agent(self):
        """Reset the episode on the current level without regenerating it."""
        self.last_message = ''
        super().update(des_file=self.grid.des)
        return super().reset()

    def seed(self, level_seed, *args, **kwargs):
        """Seed the parent environment, the level RNG and the grid.

        Args:
            level_seed: Seed used for all three; also stored as
                ``self._level_seed``.
            *args, **kwargs: Forwarded to the parent ``seed``.

        Returns:
            Whatever the parent ``seed`` call returns.
        """
        parent_result = super().seed(level_seed, *args, **kwargs)

        self._level_seed = level_seed

        self.np_random, _ = seeding.np_random(level_seed)
        self.grid.seed(level_seed)

        return parent_result

    @property
    def goal_color(self):
        """``'red'`` when the apple is the rewarded goal, else ``'blue'``."""
        return 'red' if self.is_red_goal else 'blue'

    @property
    def des_file(self):
        """The des-file description of the current grid."""
        return self.grid.des

    @property
    def grid_str(self):
        """String form of the current grid map."""
        return str(self.grid.map)

    def render(self, mode='level'):
        """Render the environment.

        With ``mode='level'`` the current des-file is rendered through
        ``get_des_file_rendering`` and returned as a numpy array; every other
        mode is delegated to the parent ``render``.
        """
        if mode == 'level':
            des_file = self.grid.des
            return np.asarray(get_des_file_rendering(des_file, wizard=True))
        else:
            return super().render(mode=mode)

class MiniHackBinaryChoice5x5(MiniHackRoomBinaryChoice):
    """Binary-choice room preset: 5x5, 8 wall placements, 50-step limit."""

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            size=5,
            n_wall=8,
            max_episode_steps=50,
            **kwargs,
        )

class MiniHackBinaryChoice15x15(MiniHackRoomBinaryChoice):
    """Binary-choice room preset: 15x15, 50 wall placements, 250-step limit."""

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            size=15,
            n_wall=50,
            max_episode_steps=250,
            **kwargs,
        )

class MiniHackBinaryChoiceLava5x5(MiniHackRoomBinaryChoice):
    """Binary-choice room preset: 5x5, 8 lava placements, 50-step limit."""

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            size=5,
            n_lava=8,
            max_episode_steps=50,
            **kwargs,
        )

class MiniHackBinaryChoiceLava15x15(MiniHackRoomBinaryChoice):
    """Binary-choice room preset: 15x15, 50 lava placements, 250-step limit."""

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            size=15,
            n_lava=50,
            max_episode_steps=250,
            **kwargs,
        )

class MiniHackBinaryChoiceLavaNonExBinomial7x7(MiniHackRoomBinaryChoice):
    """Binary-choice room preset: 7x7 lava room with non-exclusive goals.

    Requests 20 lava placements with ``skip_tile_p=0.5``, enables
    ``unexclusive_goals`` and limits episodes to 50 steps.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            size=7,
            n_lava=20,
            skip_tile_p=0.5,
            unexclusive_goals=True,
            max_episode_steps=50,
            **kwargs,
        )


# Resolve the import path of this module for the gym entry points. Prefer the
# loader's recorded name; fall back to __name__ so that module_path is always
# defined, even when the loader exposes neither attribute.
if hasattr(__loader__, 'name'):
    module_path = __loader__.name
elif hasattr(__loader__, 'fullname'):
    module_path = __loader__.fullname
else:
    module_path = __name__


# Register each preset under its gym id; max_episode_steps mirrors the value
# hard-coded in the corresponding class constructor.
registration.register(
    id='MiniHack-BinaryChoice5x5-v0',
    entry_point=module_path + ':MiniHackBinaryChoice5x5',
    max_episode_steps=50
)

registration.register(
    id="MiniHack-BinaryChoice15x15-v0",
    entry_point=module_path + ':MiniHackBinaryChoice15x15',
    max_episode_steps=250
)

registration.register(
    id='MiniHack-BinaryChoiceLava5x5-v0',
    entry_point=module_path + ':MiniHackBinaryChoiceLava5x5',
    max_episode_steps=50
)

registration.register(
    id="MiniHack-BinaryChoiceLava15x15-v0",
    entry_point=module_path + ':MiniHackBinaryChoiceLava15x15',
    max_episode_steps=250
)

registration.register(
    id="MiniHack-BinaryChoiceLavaNonExBinomial7x7-v0",
    entry_point=module_path + ':MiniHackBinaryChoiceLavaNonExBinomial7x7',
    max_episode_steps=50
)
