"""
This class mirrors the MultiGrid MultiRoomBinaryChoice environments
inside of a MiniHack environment.
"""
import re

from gym.envs import registration
import gym

from nle import nethack
from nle.env.base import FULL_ACTIONS
from nle.nethack import Command, CompassDirection
from minihack import MiniHackNavigation, LevelGenerator, RewardManager
from envs.minihack import minihack_grid 
from minihack.tiles.rendering import get_des_file_rendering

from envs.multigrid.multiroom_binary_choice import *
from util import cprint


# Flag handed to `cprint` for this module's debug messages.
DEBUG = False


# Base action set: EAT followed by every compass direction. The environment
# constructor adds one door-handling action on top of these.
_ACTIONS = [nethack.Command.EAT, *nethack.CompassDirection]

# Position of each base action inside `_ACTIONS`.
_ACTION2INDEX = {action: index for index, action in enumerate(_ACTIONS)}

# Position of each action inside NLE's `FULL_ACTIONS`.
_FULL_ACTION2INDEX = {
    action: index for index, action in enumerate(FULL_ACTIONS)
}


class MiniHackMultiRoomBinaryChoice(MiniHackNavigation):
    def __init__(self,
        *args,
        min_n_rooms,
        max_n_rooms,
        max_room_size=10,
        grid_size=14,
        fully_observable=False,
        p=0.7,
        goal_hint_p=0.0,
        rewards=None,
        reward_spreads=None,
        reward_dist='uniform',
        n_monster=0,
        n_trap=0,
        door_state='closed',
        max_episode_steps=250,
        lit=True,
        autopickup=False,
        penalty_step=0.0,
        seed=None,
        obl_correction=False,
        use_learned_beliefs=False,
        unexclusive_goals=False,
        fixed_environment=False,
        use_skeleton_key=False,
        **kwargs):
        """Create the multigrid layout generator and the MiniHack env around it.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            min_n_rooms: Passed to `MultiRoomEnv` as `minNumRooms`.
            max_n_rooms: Passed to `MultiRoomEnv` as `maxNumRooms`.
            max_room_size: Passed to `MultiRoomEnv` as `maxRoomSize`.
            grid_size: Passed to `MultiRoomEnv` as `gridSize`.
            fully_observable: When true, the parent is created with
                `wizard=True`; `reset` then issues wizard-mode commands.
            p: Threshold the goal draw is compared against; the apple ("red")
                goal is selected when the draw is below it (unless learned
                beliefs override it, see `_regenerate_level`).
            goal_hint_p: Stored on the instance; not used in this constructor.
            rewards: Per-goal reward centers, indexed 0 (apple) and
                1 (banana). Defaults to `[3, 10]`.
            reward_spreads: Per-goal reward spreads, same indexing.
                Defaults to `[2, 0]`.
            reward_dist: `'uniform'` or `'normal'`
                (see `_sample_reward_for_goal`).
            n_monster: Stored as `num_mon`; not used in this constructor.
            n_trap: Stored as `num_trap`; not used in this constructor.
            door_state: `'locked'` selects the KICK action, anything else the
                OPEN action. Forced to `'locked'` when `use_skeleton_key`.
            max_episode_steps: Forwarded to the parent constructor.
            lit: Accepted for interface compatibility; not used here.
            autopickup: Forwarded to the parent constructor.
            penalty_step: Forwarded to the parent constructor.
            seed: Seed for the multigrid env, the character grid and `seed()`.
            obl_correction: Passed to `MultiRoomEnv`; also switches the goal
                draw to the unseeded RNG in `_regenerate_level`.
            use_learned_beliefs: With `obl_correction`, start from a uniform
                belief over the two goals instead of `[p, 1 - p]`.
            unexclusive_goals: When true, both goals receive a sampled reward.
            fixed_environment: When true, the multigrid layout is not
                regenerated on each level rebuild.
            use_skeleton_key: Lock the doors, add the APPLY action and place a
                skeleton key at the agent start location.
            **kwargs: Extra keyword arguments forwarded to the parent
                constructor (`actions` and `allow_all_yn_questions` may be
                overridden through it).
        """
        # None sentinels avoid sharing one mutable default list across
        # instances while keeping the original default values.
        if rewards is None:
            rewards = [3, 10]
        if reward_spreads is None:
            reward_spreads = [2, 0]

        # Create multigrid instance
        self.minigrid_env = MultiRoomEnv(
            minNumRooms=min_n_rooms,
            maxNumRooms=max_n_rooms,
            maxRoomSize=max_room_size,
            gridSize=grid_size,
            p=p,
            rewards=rewards,
            reward_spreads=reward_spreads,
            seed=seed,
            fixed_environment=fixed_environment,
            obl_correction=obl_correction)

        self.fixed_environment = fixed_environment

        self.fully_observable = fully_observable
        if fully_observable:
            kwargs["wizard"] = True

        self.p = p

        self.goal_hint_p = goal_hint_p
        self.rewards = rewards
        self.reward_spreads = reward_spreads
        self.reward_dist = reward_dist
        self.obl_correction = obl_correction
        self.use_learned_beliefs = use_learned_beliefs
        self.unexclusive_goals = unexclusive_goals
        self.is_red_goal = False

        if obl_correction and use_learned_beliefs:
            self.belief_dist = { # Belief over R,B goals
                'goal_color': [0.5, 0.5]
            }
        else:
            self.belief_dist = { # True dist over R,B goals
                'goal_color': [p, 1-p]
            }

        self.last_message = ''

        self.num_mon = n_monster
        self.num_trap = n_trap

        self.use_skeleton_key = use_skeleton_key
        self.key_inventory_action_index = None
        self.door_state = door_state
        if use_skeleton_key:
            self.door_state = 'locked'
            door_action = Command.APPLY
        elif self.door_state == 'locked':
            door_action = Command.KICK
        else:
            door_action = Command.OPEN

        # Build a per-instance action tuple. Appending to the module-level
        # `_ACTIONS` list would grow the action set of every later instance.
        actions = tuple(_ACTIONS) + (door_action,)

        kwargs.setdefault("actions", actions)
        kwargs.setdefault("allow_all_yn_questions", True)
        kwargs["max_episode_steps"] = max_episode_steps
        # `autopickup` is a named parameter, so it never appears in kwargs;
        # forward the caller's value explicitly.
        kwargs["autopickup"] = autopickup

        self.penalty_step = penalty_step
        kwargs["penalty_step"] = self.penalty_step

        # Seeding
        self._init_empty_grid(seed=seed)

        super().__init__(*args, des_file=self.grid.des, **kwargs)

        self.seed(seed)
        self.unseeded_np_random,_ = seeding.np_random()
        self._regenerate_level()

    def _reset_episodic_counts(self):
        self.episodic_counts = {
            'solved_room_count': 0,
            'red_target': 0,
            'blue_target': 0,
            'none_target': 0,
        }

    def _init_empty_grid(self, seed=None):
        """Create `self.grid` with the multigrid env's size and custom glyphs.

        Args:
            seed: Seed passed to the `minihack_grid.Grid` constructor.
        """
        layout_env = self.minigrid_env
        height = layout_env.height
        width = layout_env.width

        self.grid = minihack_grid.Grid(width=width, height=height, seed=seed)

        # Map characters used by this environment to NetHack objects.
        custom_objects = {
            # Goals
            'R': {'name':'apple', 'symbol':'%'},
            'B': {'name':'banana', 'symbol':'%'},
            # Goal indicators
            'r': {'name':'ruby', 'symbol':'*'},
            'b': {'name':'sapphire', 'symbol':'*'},
            # Key
            'k': {'name':'skeleton key', 'symbol':'('},
        }
        self.grid.add_custom_objects(custom_objects)

    def _regenerate_level(self):
        """Rebuild the level grid, goal rewards and path metrics.

        Resets the multigrid env (unless the environment is fixed), copies its
        layout into `self.grid`, draws which goal is rewarding, registers the
        eat events on a new `RewardManager`, and caches the grid metrics.
        """
        # Reset the level contents based on multigrid env
        if not self.fixed_environment:
            self.minigrid_env.reset()

        self._reset_episodic_counts()

        self.grid.clear(' ')
        self._match_grid_to_minigrid_env()

        # @todo: Add additional monsters and objects to the walkable area
        # === Add additional items ===
        if self.use_skeleton_key:
            self.grid.set_char('k', loc=self.grid.agent_start_loc)

        # === Update rewards ===
        # As shorthand, apple is "red", banana is "blue"
        self.reward_manager = RewardManager()

        self.red_reward = 0
        self.blue_reward = 0

        if self.obl_correction:
            # Maintain uniform prior over levels
            red_outcome = self.unseeded_np_random.rand()
            if self.use_learned_beliefs:
                red_p = self.belief_dist['goal_color'][0]
            else:
                red_p = self.p
        else:
            # fixed by seed
            red_outcome = self.np_random.rand()
            red_p = self.p

        self.is_red_goal = red_outcome < red_p

        # The selected goal is always sampled first so the order of RNG
        # draws stays the same.
        if self.is_red_goal:
            self.red_reward = self._sample_reward_for_goal(0)
            if self.unexclusive_goals:
                self.blue_reward = self._sample_reward_for_goal(1)
            cprint(DEBUG, f'Eat the apple, r={self.red_reward}')
        else:
            self.blue_reward = self._sample_reward_for_goal(1)
            if self.unexclusive_goals:
                self.red_reward = self._sample_reward_for_goal(0)
            cprint(DEBUG, f'Eat the banana for, r={self.blue_reward}')

        fruit_rewards = (
            ('apple', self.red_reward),
            ('banana', self.blue_reward),
        )
        for fruit, fruit_reward in fruit_rewards:
            self.reward_manager.add_eat_event(
                fruit, reward=fruit_reward, terminal_sufficient=True)

        # === Compute shortest path metrics ===
        alias_by_char = {
            'R': 'red',
            'B': 'blue',
            '-': 'wall',
            'm': 'monster'
        }
        metrics = self.grid.get_metrics(
            goal_chars=['R', 'B'],
            clutter_chars=['-', 'm'],
            aliases=alias_by_char)

        self.passable = metrics['passable']
        self.shortest_path_length = metrics['shortest_path_lengths']
        self.n_clutter_placed = metrics['clutter_counts']

    def _match_grid_to_minigrid_env(self):
        grid = self.grid
        env = self.minigrid_env

        door_pos = []
        goal_pos = None
        empty_strs = 0
        empty_str = True
        env_map = []

        locked_doors = self.door_state == 'locked'

        for y in range(env.height):
            for x in range(env.width):
                cell = env.grid.get(x,y)
                char = ' '
                if cell is not None:
                    if cell.type == 'wall':
                        char = '|'
                    elif cell.type == 'door':
                        char = '+' if locked_doors else 'd'
                    elif cell.type == 'floor':
                        char = '.'
                    elif cell.type == 'lava':
                        char = 'L'
                    elif cell.type == 'ball':
                        if (x,y) == tuple(env.red_pos):
                            char = 'R'
                        else:
                            char = 'B'
                    elif cell.type == 'agent':
                        char = '<'

                grid.set_char(char, loc=(x,y))

        # Set walkable area
        walkable_mask = None
        for room in env.rooms:
            tl_x, tl_y = room.top
            tl_x, tl_y = tl_x + 1, tl_y + 1
            w, h = room.size
            br_x, br_y = tl_x + w - 1, tl_y + h - 1
            for y in range(tl_y, br_y):
                for x in range(tl_x, br_x):
                    if self.grid.get_char(loc=(x,y)) == ' ':
                        self.grid.set_char('.', loc=(x,y))

    def _sample_reward_for_goal(self, idx):
        mean = self.rewards[idx]
        spread = self.reward_spreads[idx]
        if self.reward_dist == 'normal':
            reward = self.unseeded_np_random.normal(mean, spread)
        elif self.reward_dist == 'uniform':
            reward = self.unseeded_np_random.rand()*2*spread + mean - spread
        else:
            raise ValueError(f'Unsupported reward dist {self.reward_dist}.')

        return reward

    def step(self, action):
        """Advance the env by one action and answer key prompts automatically.

        After the parent step, the decoded, lower-cased message is inspected:
        if it asks what to use or apply, the stored key inventory action is
        sent; if it asks whether to unlock, `y` is sent. In both cases the
        returned `obs, r, done, info` are those of the follow-up step.

        When the episode ends, `self.episodic_counts` is updated and attached
        to `info['episodic_counts']`. The ending counts as reaching the apple
        or banana when the action equals the NW entry of the base action set
        (the original code calls this "y", presumably the key that confirms
        the eat prompt) and the previous message mentioned that fruit.

        Args:
            action: Index into this environment's action set.

        Returns:
            The `(obs, r, done, info)` tuple of the last parent step taken.
        """
        parent_step = super().step

        def decode_message(observation):
            return observation['message'].tobytes().decode('utf-8').lower()

        def step_with_full_actions(full_action_index):
            # Temporarily expose the complete NLE action table so an action
            # outside this env's reduced set can be sent. The reduced table
            # is restored even if the parent step raises.
            reduced_actions = self._actions
            self._actions = FULL_ACTIONS
            try:
                return parent_step(full_action_index)
            finally:
                self._actions = reduced_actions

        obs, r, done, info = parent_step(action)
        cur_message = decode_message(obs)

        if 'what do you want to use or apply?' in cur_message:
            # Automatically choose key
            obs, r, done, info = step_with_full_actions(
                self.key_inventory_action_index)
            cur_message = decode_message(obs)
        elif 'unlock it?' in cur_message:
            obs, r, done, info = step_with_full_actions(
                _FULL_ACTION2INDEX[ord('y')])
            cur_message = decode_message(obs)

        if done:
            # check if eating event
            action_is_y = action == _ACTION2INDEX[nethack.CompassDirection.NW]
            num_rooms = len(self.minigrid_env.rooms)
            if action_is_y and 'apple' in self.last_message:
                self.episodic_counts['red_target'] = 1
                self.episodic_counts['solved_room_count'] = num_rooms
            elif action_is_y and 'banana' in self.last_message:
                self.episodic_counts['blue_target'] = 1
                self.episodic_counts['solved_room_count'] = num_rooms
            else:
                self.episodic_counts['none_target'] = 1

            info['episodic_counts'] = self.episodic_counts

        self.last_message = cur_message

        return obs, r, done, info

    def reset(self):
        """Regenerate the level, reset the parent env and return the first obs.

        With `fully_observable`, a fixed sequence of wizard-mode keystrokes is
        typed into the underlying env after the parent reset. With
        `use_skeleton_key`, a PICKUP action is issued and the inventory action
        for the key is stored in `self.key_inventory_action_index`.

        Returns:
            The observation of the last step or reset performed here.
        """
        self._regenerate_level()
        # Hand the freshly generated des-file to the parent before resetting.
        super().update(des_file=self.grid.des)

        if self.fully_observable:
            super().reset()

            # Type the wizard commands one character at a time into the
            # low-level env (the last one wishes for a potion of object
            # detection).
            for c in "#wizintrinsic\rt\r\r#wizmap\r#wizwish\ra potion of object detection\r":
                obs, sds = self.env.step(ord(c))
            msg = (
                obs[self._original_observation_keys.index("message")]
                .tobytes()
                .decode('utf-8')
            )

            # Send 'q' followed by the first character of the message --
            # presumably quaffing the wished-for potion by its inventory
            # letter; TODO confirm.
            for c in f"q{msg[0]}":
                obs, sds = self.env.step(ord(c))

            # Take one WAIT step through the parent, using the full action
            # table because WAIT is looked up in FULL_ACTIONS.
            _actions_tmp = self._actions
            self._actions = FULL_ACTIONS
            obs, _, _, _ = super().step(self._actions.index(nethack.MiscDirection.WAIT))
            self._actions = _actions_tmp
        else:
            obs = super().reset()

        if self.use_skeleton_key:
            # Automatically pick up key
            _actions_tmp = self._actions
            self._actions = FULL_ACTIONS
            obs, r, done, info = super().step(_FULL_ACTION2INDEX[Command.PICKUP])
            self._actions = _actions_tmp

            cur_message = obs['message'].tobytes().decode('utf-8').lower()

            # The first character of the pickup message is used as the key's
            # inventory letter (assumes the message starts with it).
            self.key_inventory_action_index = _FULL_ACTION2INDEX[ord(cur_message[0])]

        return obs

    def reset_agent(self):
        """Reset the parent env on the current des-file without regenerating it."""
        parent = super()
        parent.update(des_file=self.grid.des)
        return parent.reset()

    def seed(self, level_seed, *args, **kwargs):
        """Seed the parent env, this env's RNG, the grid and the multigrid env.

        Args:
            level_seed: Seed applied to every component; also stored in
                `self._level_seed`.
            *args: Extra positional arguments for the parent `seed`.
            **kwargs: Extra keyword arguments for the parent `seed`.
        """
        super().seed(level_seed, *args, **kwargs)
        self._level_seed = level_seed

        seeded_rng, _unused_seed = seeding.np_random(level_seed)
        self.np_random = seeded_rng

        self.grid.seed(level_seed)
        self.minigrid_env.seed(level_seed)

    @property
    def belief_spec(self):
        return {
            'goal_color': {
                'type': 'categorical',
                'size': 2
            }
        }

    @property
    def belief_tokens(self):
        red_outcome = self.np_random.rand()
        is_red_goal = red_outcome < self.p # sample from true dist

        tokens = {
            'goal_color': np.array([is_red_goal, not is_red_goal], dtype=np.int32)
        }
        return tokens

    def set_belief_dist(self, belief_dist):
        """Replace `self.belief_dist` with the given belief distribution.

        Other methods of this class read `belief_dist['goal_color'][0]`, so
        the value is expected to provide that entry.
        """
        self.belief_dist = belief_dist

    @property
    def aux_properties(self):
        return {
            'red_goal': self.is_red_goal,
            'num_rooms': len(self.minigrid_env.rooms),
            'belief_red': self.belief_dist['goal_color'][0]
        }

    @property
    def des_file(self):
        """Return the des-file description of the current grid (`self.grid.des`)."""
        return self.grid.des

    @property
    def grid_str(self):
        return self.grid.map.__str__()

    def render(self, mode='level'):
        """Render the env; mode 'level' draws the des-file as an array.

        Args:
            mode: 'level' renders `self.grid.des` through
                `get_des_file_rendering`; any other mode is delegated to the
                parent `render`.
        """
        if mode != 'level':
            return super().render(mode=mode)

        rendering = get_des_file_rendering(self.grid.des, wizard=True)
        return np.asarray(rendering)


class MiniHackMultiRoomBinaryChoiceNRooms(MiniHackMultiRoomBinaryChoice):
    """Binary-choice environment with exactly `num_rooms` rooms.

    Uses a 21-cell grid, rooms of at most size 5, locked doors and uniform
    reward sampling.
    """

    def __init__(
        self,
        *args,
        seed=None,
        fixed_environment=False,
        num_rooms=1,
        p=0.0,
        rewards=(0,1),
        reward_spreads=(0,0),
        use_walls=False,
        obl_correction=False,
        use_learned_beliefs=False,
        fully_observable=False,
        use_skeleton_key=False,
        max_episode_steps=500,
        **kwargs):
        """Configure the base environment for a fixed room count.

        Args:
            *args: Positional arguments forwarded to the base constructor.
            seed: Seed forwarded to the base constructor.
            fixed_environment: Forwarded to the base constructor.
            num_rooms: Used as both the minimum and maximum number of rooms.
            p: Forwarded to the base constructor.
            rewards: Forwarded to the base constructor.
            reward_spreads: Forwarded to the base constructor.
            use_walls: Accepted for interface compatibility; not used here.
            obl_correction: Forwarded to the base constructor.
            use_learned_beliefs: Forwarded to the base constructor.
            fully_observable: Forwarded to the base constructor.
            use_skeleton_key: Forwarded to the base constructor.
            max_episode_steps: Forwarded to the base constructor. It was
                previously accepted but dropped, so the base default applied.
            **kwargs: Extra keyword arguments forwarded to the base
                constructor.
        """
        super().__init__(
            *args,
            min_n_rooms=num_rooms,
            max_n_rooms=num_rooms,
            max_room_size=5,
            grid_size=21,
            fully_observable=fully_observable,
            p=p,
            rewards=rewards,
            reward_spreads=reward_spreads,
            reward_dist='uniform',
            door_state='locked',
            max_episode_steps=max_episode_steps,
            use_skeleton_key=use_skeleton_key,
            obl_correction=obl_correction,
            use_learned_beliefs=use_learned_beliefs,
            fixed_environment=fixed_environment,
            seed=seed,
            **kwargs
        )

class MiniHackMultiRoomBinaryChoiceN4Random(MiniHackMultiRoomBinaryChoice):
    """Binary-choice environment with between 1 and 4 rooms.

    Uses a 21-cell grid, rooms of at most size 5, locked doors, partial
    observability and uniform reward sampling.
    """

    def __init__(
        self,
        *args,
        seed=None,
        fixed_environment=False,
        p=0.7,
        rewards=(3,10),
        reward_spreads=(2,0),
        use_walls=False,
        obl_correction=False,
        use_learned_beliefs=False,
        use_skeleton_key=False,
        max_episode_steps=250,
        **kwargs):
        """Configure the base environment for 1-4 rooms.

        The signature now declares `*args` and `**kwargs`. The body already
        forwarded them, so every instantiation raised NameError before.

        Args:
            *args: Positional arguments forwarded to the base constructor.
            seed: Seed forwarded to the base constructor.
            fixed_environment: Forwarded to the base constructor.
            p: Forwarded to the base constructor.
            rewards: Forwarded to the base constructor.
            reward_spreads: Forwarded to the base constructor.
            use_walls: Accepted for interface compatibility; not used here.
            obl_correction: Forwarded to the base constructor.
            use_learned_beliefs: Forwarded to the base constructor.
            use_skeleton_key: Forwarded to the base constructor.
            max_episode_steps: Forwarded to the base constructor.
            **kwargs: Extra keyword arguments forwarded to the base
                constructor.
        """
        super().__init__(
            *args,
            min_n_rooms=1,
            max_n_rooms=4,
            max_room_size=5,
            grid_size=21,
            fully_observable=False,
            p=p,
            rewards=rewards,
            reward_spreads=reward_spreads,
            reward_dist='uniform',
            door_state='locked',
            max_episode_steps=max_episode_steps,
            use_skeleton_key=use_skeleton_key,
            obl_correction=obl_correction,
            use_learned_beliefs=use_learned_beliefs,
            fixed_environment=fixed_environment,
            seed=seed,
            **kwargs
        )


class MiniHackMultiRoomBinaryChoiceN4RandomSKey(MiniHackMultiRoomBinaryChoiceN4Random):
    """N4-Random variant that always enables `use_skeleton_key`."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent, adding `use_skeleton_key=True`."""
        super().__init__(*args, use_skeleton_key=True, **kwargs)

class MiniHackMultiRoomBinaryChoiceN6Random(MiniHackMultiRoomBinaryChoice):
    """Binary-choice environment with between 1 and 6 rooms.

    Uses a 21-cell grid, rooms of at most size 5, locked doors, partial
    observability and uniform reward sampling.
    """

    def __init__(
        self,
        *args,
        seed=None,
        fixed_environment=False,
        p=0.7,
        rewards=(3,10),
        reward_spreads=(2,0),
        use_walls=False,
        obl_correction=False,
        use_learned_beliefs=False,
        use_skeleton_key=False,
        max_episode_steps=300,
        **kwargs):
        """Configure the base environment for 1-6 rooms.

        `*args`, `max_episode_steps` and `**kwargs` are now forwarded to the
        base constructor. They were previously accepted but silently dropped.

        Args:
            *args: Positional arguments forwarded to the base constructor.
            seed: Seed forwarded to the base constructor.
            fixed_environment: Forwarded to the base constructor.
            p: Forwarded to the base constructor.
            rewards: Forwarded to the base constructor.
            reward_spreads: Forwarded to the base constructor.
            use_walls: Accepted for interface compatibility; not used here.
            obl_correction: Forwarded to the base constructor.
            use_learned_beliefs: Forwarded to the base constructor.
            use_skeleton_key: Forwarded to the base constructor.
            max_episode_steps: Forwarded to the base constructor.
            **kwargs: Extra keyword arguments forwarded to the base
                constructor.
        """
        super().__init__(
            *args,
            min_n_rooms=1,
            max_n_rooms=6,
            max_room_size=5,
            grid_size=21,
            fully_observable=False,
            p=p,
            rewards=rewards,
            reward_spreads=reward_spreads,
            reward_dist='uniform',
            door_state='locked',
            max_episode_steps=max_episode_steps,
            use_skeleton_key=use_skeleton_key,
            obl_correction=obl_correction,
            use_learned_beliefs=use_learned_beliefs,
            fixed_environment=fixed_environment,
            seed=seed,
            **kwargs
        )

class MiniHackMultiRoomBinaryChoiceN6RandomSKey(MiniHackMultiRoomBinaryChoiceN6Random):
    """N6-Random variant that always enables `use_skeleton_key`."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent, adding `use_skeleton_key=True`."""
        super().__init__(*args, use_skeleton_key=True, **kwargs)

class MiniHackMultiRoomBinaryChoiceN8Random(MiniHackMultiRoomBinaryChoice):
    """Binary-choice environment with between 1 and 8 rooms.

    Uses a 21-cell grid, rooms of at most size 5, locked doors, partial
    observability and uniform reward sampling.
    """

    def __init__(
        self,
        *args,
        seed=None,
        fixed_environment=False,
        p=0.7,
        rewards=(3,10),
        reward_spreads=(2,0),
        use_walls=False,
        obl_correction=False,
        use_learned_beliefs=False,
        use_skeleton_key=False,
        max_episode_steps=300,
        **kwargs):
        """Configure the base environment for 1-8 rooms.

        `*args`, `max_episode_steps` and `**kwargs` are now forwarded to the
        base constructor. They were previously accepted but silently dropped.

        Args:
            *args: Positional arguments forwarded to the base constructor.
            seed: Seed forwarded to the base constructor.
            fixed_environment: Forwarded to the base constructor.
            p: Forwarded to the base constructor.
            rewards: Forwarded to the base constructor.
            reward_spreads: Forwarded to the base constructor.
            use_walls: Accepted for interface compatibility; not used here.
            obl_correction: Forwarded to the base constructor.
            use_learned_beliefs: Forwarded to the base constructor.
            use_skeleton_key: Forwarded to the base constructor.
            max_episode_steps: Forwarded to the base constructor.
            **kwargs: Extra keyword arguments forwarded to the base
                constructor.
        """
        super().__init__(
            *args,
            min_n_rooms=1,
            max_n_rooms=8,
            max_room_size=5,
            grid_size=21,
            fully_observable=False,
            p=p,
            rewards=rewards,
            reward_spreads=reward_spreads,
            reward_dist='uniform',
            door_state='locked',
            max_episode_steps=max_episode_steps,
            use_skeleton_key=use_skeleton_key,
            obl_correction=obl_correction,
            use_learned_beliefs=use_learned_beliefs,
            fixed_environment=fixed_environment,
            seed=seed,
            **kwargs
        )

class MiniHackMultiRoomBinaryChoiceN8RandomSKey(MiniHackMultiRoomBinaryChoiceN8Random):
    """N8-Random variant that always enables `use_skeleton_key`."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent, adding `use_skeleton_key=True`."""
        super().__init__(*args, use_skeleton_key=True, **kwargs)

# Dotted path of this module, used as the prefix of the gym entry points.
if hasattr(__loader__, 'name'):
    module_path = __loader__.name
elif hasattr(__loader__, 'fullname'):
    module_path = __loader__.fullname
else:
    # The loader exposes neither attribute; fall back to the module's own
    # name so that `module_path` is always bound.
    module_path = __name__


def set_global(name, value):
    """Bind `value` to `name` in this module's global namespace."""
    module_namespace = globals()
    module_namespace[name] = value

def _create_constructor(num_rooms):
    """Build an `__init__` that calls the NRooms constructor with `num_rooms`.

    The returned function accepts keyword arguments only and passes them on
    unchanged.
    """
    def constructor(self, **kwargs):
        parent_init = MiniHackMultiRoomBinaryChoiceNRooms.__init__
        return parent_init(self, num_rooms=num_rooms, **kwargs)

    return constructor

# Generate one NRooms subclass per room count 1..MAX_N, publish it in the
# module globals under its class name, and register it with gym.
MAX_N = 20
for n in range(1, MAX_N + 1):
    class_name = f"MiniHackMultiRoomBinaryChoice{n}Rooms"
    env = type(
        class_name,
        (MiniHackMultiRoomBinaryChoiceNRooms,),
        dict(__init__=_create_constructor(n)),
    )
    # The entry point below refers to the class by name, so it has to be
    # reachable as a module attribute.
    globals()[class_name] = env
    registration.register(
        id=f'MiniHack-MultiRoomBinaryChoice-N{n}-v0',
        entry_point=f'{module_path}:{class_name}',
        max_episode_steps=500,
    )


# Random-layout variants; each "-SKey" id points at the subclass that enables
# the skeleton key.
registration.register(
    id='MiniHack-MultiRoomBinaryChoice-N4-Random-v0',
    entry_point=f'{module_path}:MiniHackMultiRoomBinaryChoiceN4Random',
    max_episode_steps=250,
)

registration.register(
    id='MiniHack-MultiRoomBinaryChoice-N4-Random-SKey-v0',
    entry_point=f'{module_path}:MiniHackMultiRoomBinaryChoiceN4RandomSKey',
    max_episode_steps=250,
)

registration.register(
    id='MiniHack-MultiRoomBinaryChoice-N6-Random-v0',
    entry_point=f'{module_path}:MiniHackMultiRoomBinaryChoiceN6Random',
    max_episode_steps=300,
)

registration.register(
    id='MiniHack-MultiRoomBinaryChoice-N6-Random-SKey-v0',
    entry_point=f'{module_path}:MiniHackMultiRoomBinaryChoiceN6RandomSKey',
    max_episode_steps=300,
)

registration.register(
    id='MiniHack-MultiRoomBinaryChoice-N8-Random-v0',
    entry_point=f'{module_path}:MiniHackMultiRoomBinaryChoiceN8Random',
    max_episode_steps=300,
)

registration.register(
    id='MiniHack-MultiRoomBinaryChoice-N8-Random-SKey-v0',
    entry_point=f'{module_path}:MiniHackMultiRoomBinaryChoiceN8RandomSKey',
    max_episode_steps=300,
)