import re

import numpy as np

from gym.utils import seeding
from gym.envs import registration
from minihack import MiniHackNavigation, LevelGenerator, RewardManager
from minihack.tiles.rendering import get_des_file_rendering

from envs.minihack import minihack_grid 
from nle import nethack
from nle.env.base import FULL_ACTIONS

from util import cprint


# Reduced action set exposed to the agent by default: every member of
# nethack.CompassDirection, in enum order. Kept as an immutable tuple.
_ACTIONS = tuple(nethack.CompassDirection)

# Maps each reduced-set action to its position within _ACTIONS.
_ACTION2INDEX = {action: index for index, action in enumerate(_ACTIONS)}

# Maps each action of NLE's full action set to its position in FULL_ACTIONS.
_FULL_ACTION2INDEX = {
    action: index for index, action in enumerate(FULL_ACTIONS)
}

# Module-level debug switch.
DEBUG = False


class MiniHackWeaponChoice(MiniHackNavigation):
    """MiniHack navigation task with an optional weapon to fetch first.

    The map is a fixed layout: a left area (a room, or a one-tile-high
    tunnel when ``tunnel=True``) and a right room, joined by a corridor in
    which the agent starts. Every level contains a "boss" monster in the
    left area, ``n_monster`` minions and the goal (``'>'``) in the right
    room, and ``n_wall`` extra wall tiles in the right room. With
    probability ``p`` a weapon is placed next to the left door: a wand
    (object name ``'death'``) by default, or a dwarvish mattock plus red
    dragon scale mail in ``melee`` mode.

    ``step`` scripts the item handling for the agent (pick up, wield, wear,
    answering the zap prompt) and adds shaped rewards on top of the base
    environment reward: +1.0 for picking up the weapon, +1.0 for a
    "you kill the <boss|minion>" message, and, in melee mode only, +2.0 when
    the episode ends with ``TASK_SUCCESSFUL`` while ``no_damage_bonus`` is
    still set.
    """
    def __init__(
        self,
        *args,
        fully_observable=False, # Not supported
        p=1.0,
        reward_dist='uniform',
        n_wall=0,
        n_monster=3,
        tunnel=False,
        melee=False,
        max_episode_steps=250,
        lit=True,
        autopickup=False,
        penalty_step=0.0,
        seed=None,
        obl_correction=False,
        use_learned_beliefs=False,
        goal_hint_p=0.0,
        fixed_environment=False,
        **kwargs
    ):
        """Build the environment and generate the first level.

        Args:
            fully_observable: Accepted but ignored (not supported).
            p: Probability that the weapon is present in a generated level.
            reward_dist: Stored on the instance; not otherwise used here.
            n_wall: Number of extra wall tiles placed in the right room.
            n_monster: Number of minions in the right room. Forced to 5
                when ``melee`` is true.
            tunnel: Use the tunnel layout instead of the two-room layout.
            melee: Melee variant (caveman character, mattock and armor,
                no ZAP action) instead of the wand variant (tourist).
            max_episode_steps: Forwarded to the base environment.
            lit: Forwarded to the placeholder ``LevelGenerator``.
            autopickup: Default for the base ``autopickup`` option.
            penalty_step: Stored as ``self.penalty_step`` after base init.
            seed: Seed for the level RNG and the grid.
            obl_correction: Draw weapon presence from an unseeded RNG.
            use_learned_beliefs: With ``obl_correction``, use
                ``belief_dist['weapon_present'][0]`` instead of ``p``.
            goal_hint_p: Stored on the instance; not otherwise used here.
            fixed_environment: Accepted but ignored.
            **kwargs: Forwarded to ``MiniHackNavigation``. ``character`` is
                always overridden; ``actions``, ``autopickup`` and
                ``allow_all_yn_questions`` are only given defaults.
        """
        self.melee = melee
        if melee:
            kwargs["character"] = "cav-hum-law-mal"
            n_monster = 5
            default_actions = _ACTIONS
        else:
            kwargs["character"] = "tou-hum-law-fem" # Tourist by default
            # The wand variant needs ZAP on top of the movement actions.
            # Build a fresh tuple: the module-level _ACTIONS is an immutable
            # tuple (it has no .append) and is shared by every instance.
            default_actions = tuple(_ACTIONS) + (nethack.Command.ZAP,)

        kwargs["max_episode_steps"] = max_episode_steps
        kwargs["autopickup"] = kwargs.pop("autopickup", autopickup)
        kwargs["actions"] = kwargs.pop("actions", default_actions)

        kwargs["allow_all_yn_questions"] = kwargs.pop(
            "allow_all_yn_questions", True
        )

        self.n_wall = n_wall
        self.n_monster = n_monster

        self.tunnel = tunnel
        self.p = p
        self.obl_correction = obl_correction
        self.use_learned_beliefs = use_learned_beliefs
        self.belief_dist = {
            'weapon_present': [0.5]
        }
        self.goal_hint_p = goal_hint_p
        self.reward_dist = reward_dist
        self.is_weapon_present = False

        # Inventory letters learned from pickup messages during an episode.
        self.weapon_inventory_index = None
        self.armor_inventory_index = None
        # Kept for backward compatibility; nothing in this class reads it.
        self.weapon_armor_index = None

        self.last_message = ''

        # Must run before super().__init__: seed() (below, and possibly
        # from the base constructor) needs self.grid to exist.
        self._init_empty_grid(seed=seed)

        # 1x1 placeholder level; the real des-file is supplied on reset().
        self.lvl_gen = LevelGenerator(w=1, h=1, lit=lit)
        self.reward_manager = RewardManager()

        super().__init__(*args,
            des_file=self.lvl_gen.get_des(), **kwargs)

        # Assigned after base init on purpose, so it is not overwritten.
        # NOTE(review): presumably this overrides the attribute the base
        # env reads for its step penalty -- confirm. penalty_step is never
        # passed through kwargs.
        self.penalty_step = penalty_step

        self.seed(seed) # sets np_random seed

        self.unseeded_np_random,_ = seeding.np_random()

        self._regenerate_level()

    def _init_empty_grid(self, seed=None):
        """Create ``self.grid`` from the fixed layout and define its regions.

        Reads ``self.tunnel`` and ``self.melee``. Besides the grid this
        sets the monster names/symbols, the rectangular masks used for
        random placement, the fixed item/agent/door locations and a fresh
        ``RewardManager``.
        """
        fixed_grid_str_chamber = """
--------------                    
|............|     -------------
|............|     |...........|
|............|     |...........|
|.............#####............|
|............|     |...........|
|............|     |...........|
|............|     -------------
--------------                  
"""
        fixed_grid_str_tunnel = """
                        -------------
                        |...........|
-------------------     |...........|
|..................#####............|
-------------------     |...........|
                        |...........|
                        -------------
"""
        fixed_grid_str = (
            fixed_grid_str_tunnel if self.tunnel else fixed_grid_str_chamber
        )

        h,w = minihack_grid.Grid.fixed_grid_str_dim(fixed_grid_str)
        self.grid = minihack_grid.Grid(
            width=w,
            height=h,
            fixed_grid_str=fixed_grid_str,
            seed=seed)

        # Grid characters for the items that may be placed in a level.
        self.grid.add_custom_objects({
                '/': {'name':'death', 'symbol':'/'},
                ')': {'name':'dwarvish mattock', 'symbol':')'},
                '[': {'name':'red dragon scale mail', 'symbol':'['},
            })

        # Melee mode uses the same monster type for the boss and the minions.
        if self.melee:
            self.boss_name, self.boss_symbol = 'pyrolisk', "c"
            self.minion_name, self.minion_symbol = 'pyrolisk', "c"
        else:
            self.boss_name, self.boss_symbol = 'zruty', "z"
            self.minion_name, self.minion_symbol = 'jackal', 'd'

        # 'B' and 'j' are the grid characters used for boss and minions.
        self.grid.add_custom_monsters({
                'B': {'name': self.boss_name, 'symbol':self.boss_symbol},
                'j': {'name': self.minion_name, 'symbol': self.minion_symbol}
            })

        # Coordinates below are (x, y) pairs into the layout strings above.
        if self.tunnel:
            self.room_left_mask = self.grid.mask_rect((1,3), (17,3))
            self.room_left_monster_mask = self.grid.mask_rect((1,3), (1,3))
            self.room_right_mask = self.grid.mask_rect((25,1), (35,5))
            self.room_right_wall_mask = self.grid.mask_rect((26,1), (32,5))
            self.room_right_monster_mask = self.grid.mask_rect((35,1), (35,5))
            self.armor_loc = (17,3)
            self.weapon_loc = (16,3) if self.melee else (17,3)
            self.agent_start_loc = (21,3)
            self.left_door_loc = (18,3)
            self.right_door_loc = (24,3)
        else:
            self.room_left_mask = self.grid.mask_rect((1,1), (12,7))
            self.room_left_monster_mask = self.grid.mask_rect((1,1), (1,7))
            self.room_right_mask = self.grid.mask_rect((20,2), (30,6))
            self.room_right_wall_mask = self.grid.mask_rect((21,2), (30,6))
            self.room_right_monster_mask = self.grid.mask_rect((30,2), (30,6))
            self.armor_loc = (12,4)
            self.weapon_loc = (11,4) if self.melee else (12,4)
            self.agent_start_loc = (16,4)
            self.left_door_loc = (13,4)
            self.right_door_loc = (19,4)

        if self.melee:
            self.goal_mask = self.grid.mask_rect((27,2), (28,6))
        else:
            self.goal_mask = self.room_right_mask

        # Should consider supporting a dense reward version of this task.
        # Note kill events don't seem to currently work in MiniHack, which
        # is why kill rewards are derived from game messages in step().
        self.reward_manager = RewardManager()

    def _regenerate_level(self):
        """Populate ``self.grid`` with a new random level.

        Places, in this order: the boss in the left area, ``n_wall`` wall
        tiles and ``n_monster`` minions in the right room, the goal, the
        agent, the two doors, and finally (with probability ``p``, or the
        learned belief under OBL correction) the weapon items. Afterwards
        the grid metrics are cached on the instance.

        NOTE(review): the placement calls presumably draw from the grid's
        seeded RNG, so their order should stay stable for reproducibility.
        """
        # === Generate topology ===
        self.grid.clear()

        # Place boss in left area
        self.grid.set_char('B', loc='random_free', mask=~self.room_left_monster_mask)

        # Place walls in right room
        for _ in range(self.n_wall):
            self.grid.set_char('-', loc='random_free', mask=~self.room_right_wall_mask)

        # Place minions in right room
        for _ in range(self.n_monster):
            self.grid.set_char('j', loc='random_free', mask=~self.room_right_monster_mask)

        # Place goal in right room
        self.grid.set_char('>', loc='random_free', mask=~self.goal_mask)

        # Place agent at center of corridor
        self.grid.set_char('<', loc=self.agent_start_loc)

        # Add closed, unlocked doors into each room
        self.grid.set_char('d', loc=self.left_door_loc)
        self.grid.set_char('d', loc=self.right_door_loc)

        # Generate weapon with probability p. Under OBL correction the draw
        # uses the unseeded RNG, optionally with the learned belief.
        weapon_p = self.p
        if self.obl_correction:
            if self.use_learned_beliefs:
                weapon_p = self.belief_dist['weapon_present'][0]
            self.is_weapon_present = self.unseeded_np_random.rand() < weapon_p
        else:
            self.is_weapon_present = self.np_random.rand() < weapon_p

        if self.is_weapon_present:
            if self.melee:
                self.grid.set_char(')', loc=self.weapon_loc)
                self.grid.set_char('[', loc=self.armor_loc)
            else:
                self.grid.set_char('/', loc=self.weapon_loc)

        if self.melee:
            # Re-armed for every new level; cleared again by step().
            self.no_damage_bonus = True

        # === Compute shortest path metrics ===
        grid_info = \
            self.grid.get_metrics(
                goal_chars=['>', '/'],
                clutter_chars=['-', 'j'],
                aliases={
                    '>': 'goal',
                    '/': 'weapon',
                })
        self.passable = grid_info['passable']
        self.shortest_path_length = grid_info['shortest_path_lengths']
        self.n_clutter_placed = grid_info['clutter_counts']

    def _reset_episode_state(self):
        """Clear state that is only valid within a single episode.

        Inventory letters belong to the previous game's inventory, so a
        stale letter must not be replayed at a later zap prompt.
        """
        self.weapon_inventory_index = None
        self.armor_inventory_index = None
        self.last_message = ''
        if self.melee:
            self.no_damage_bonus = True

    @staticmethod
    def _decode_message(obs):
        """Return the lower-cased game message of observation ``obs``."""
        return obs['message'].tobytes().decode('utf-8').lower()

    def _step_full_action(self, action):
        """Send one raw NLE action, bypassing the reduced action set.

        ``self._actions`` is temporarily replaced by ``FULL_ACTIONS`` so
        the index looked up in ``_FULL_ACTION2INDEX`` is valid for the base
        ``step``. The previous action set is restored even if the lookup
        or the step raises.

        Returns:
            ``(obs, done, info)``. The base reward of the scripted step is
            discarded.
        """
        saved_actions = self._actions
        self._actions = FULL_ACTIONS
        try:
            obs, _, done, info = super().step(_FULL_ACTION2INDEX[action])
        finally:
            self._actions = saved_actions
        return obs, done, info

    def _press_keys(self, keys, obs, done, info):
        """Send each character of ``keys`` as a raw action until done.

        Returns the last ``(obs, done, info)``; the inputs are returned
        unchanged when the episode is already over.
        """
        for key in keys:
            if done:
                break
            obs, done, info = self._step_full_action(ord(key))
        return obs, done, info

    def step(self, action):
        """Advance one agent step, scripting item use and shaping reward.

        After the base step, the game message decides what happens next:

        * standing on the wand or mattock: pick it up (+1.0 on success)
          and, for the mattock, wield it;
        * standing on the scale mail: pick it up and put it on;
        * at the zap prompt: answer with the wand's inventory letter if it
          is known for this episode;
        * a kill message for the boss or a minion: +1.0.

        Returns:
            ``(obs, reward, done, info)`` from the last underlying step,
            with the shaped reward described on the class.
        """
        obs, r, done, info = super().step(action)

        cur_message = self._decode_message(obs)

        # re.match anchors at the start, so these only fire when the
        # message begins with the given phrase.
        on_wand = re.match(r"you see here .* wand", cur_message)
        on_melee_weapon = re.match(r"you see here .* broad pick", cur_message)
        on_armor = re.match(r"you see here .* scale mail", cur_message)

        match_damage = re.match(r"the fire doesn't feel hot", cur_message)

        if self.melee and match_damage:
            self.no_damage_bonus = False

        if on_wand or on_melee_weapon:
            obs, done, info = self._step_full_action(nethack.Command.PICKUP)
            cur_message = self._decode_message(obs)

            # A pickup message starts with "<inventory letter> - <item>".
            wand_pickup = re.match(r"^([a-z]) - .* wand", cur_message)
            if wand_pickup:
                self.weapon_inventory_index = wand_pickup[1]
                r += 1.0
            else:
                melee_pickup = re.match(
                    r"^([a-z]) - .* broad pick", cur_message)
                if melee_pickup:
                    self.weapon_inventory_index = melee_pickup[1]
                    r += 1.0

                    # 'w' followed by the letter wields the weapon.
                    obs, done, info = self._press_keys(
                        f'w{self.weapon_inventory_index}', obs, done, info)

        elif on_armor:
            obs, done, info = self._step_full_action(nethack.Command.PICKUP)
            cur_message = self._decode_message(obs)

            armor_pickup = re.match(r"^([a-z]) - .* scale mail", cur_message)
            if armor_pickup:
                self.armor_inventory_index = armor_pickup[1]

                # Key sequence 'T', 'W', then the armor's letter.
                obs, done, info = self._press_keys(
                    f'TW{self.armor_inventory_index}', obs, done, info)

        elif 'what do you want to zap?' in cur_message:
            if self.weapon_inventory_index is not None:
                obs, done, info = self._step_full_action(
                    ord(self.weapon_inventory_index))

        elif (f'you kill the {self.boss_name}' in cur_message
                or f'you kill the {self.minion_name}' in cur_message):
            r += 1.0

        # Record the message of the final observation that is returned.
        self.last_message = self._decode_message(obs)

        if done and self.melee and self.no_damage_bonus:
            end_status = info.get('end_status', None)
            if end_status is not None and end_status.name == 'TASK_SUCCESSFUL':
                r += 2.0

        return obs, r, done, info

    def reset(self):
        """Generate a new level, load it, and reset the base environment."""
        self._regenerate_level()
        self._reset_episode_state()
        super().update(des_file=self.grid.des)
        return super().reset()

    def reset_agent(self):
        """Reset the episode on the current level without regenerating it."""
        self._reset_episode_state()
        super().update(des_file=self.grid.des)
        return super().reset()

    def seed(self, level_seed, *args, **kwargs):
        """Seed the base environment, the level RNG and the grid.

        ``level_seed`` is also remembered as ``self._level_seed``.
        """
        super().seed(level_seed, *args, **kwargs)

        self._level_seed = level_seed

        self.np_random, _ = seeding.np_random(level_seed)
        self.grid.seed(level_seed)

    @property
    def belief_spec(self):
        """Description of the belief variables exposed by this env."""
        return {
            'weapon_present': {
                'type': 'categorical',
                'size': 2
            }
        }

    @property
    def belief_tokens(self):
        """Ground-truth belief targets for the current level.

        ``weapon_present`` is an int32 pair ``[present, absent]`` with
        exactly one entry set to 1.
        """
        present = self.is_weapon_present
        return {
            'weapon_present': np.array([present, not present], dtype=np.int32)
        }

    def set_belief_dist(self, belief_dist):
        """Replace the belief distribution used under OBL correction."""
        self.belief_dist = belief_dist

    @property
    def des_file(self):
        """The des-file text of the current grid."""
        return self.grid.des

    @property
    def grid_str(self):
        """String form of the current grid map."""
        return self.grid.map.__str__()

    def render(self, mode='level'):
        """Render the env; ``mode='level'`` draws the des-file as an array.

        Any other mode is delegated to the base environment.
        """
        if mode != 'level':
            return super().render(mode=mode)

        des_file = self.grid.des
        return np.asarray(get_des_file_rendering(des_file, wizard=True))


class MiniHackWeaponChoiceTunnel(MiniHackWeaponChoice):
    """Preset of ``MiniHackWeaponChoice`` that fixes ``tunnel=True``."""

    def __init__(self, *args, **kwargs):
        # Passing ``tunnel`` again through kwargs raises a TypeError.
        super().__init__(*args, tunnel=True, **kwargs)

class MiniHackWeaponChoiceMelee(MiniHackWeaponChoice):
    """Preset of ``MiniHackWeaponChoice`` that fixes ``melee=True``."""

    def __init__(self, *args, **kwargs):
        # Passing ``melee`` again through kwargs raises a TypeError.
        super().__init__(*args, melee=True, **kwargs)

class MiniHackWeaponChoiceTunnelMelee(MiniHackWeaponChoice):
    """Preset that fixes both ``tunnel=True`` and ``melee=True``."""

    def __init__(self, *args, **kwargs):
        # Passing either fixed option again through kwargs raises a TypeError.
        super().__init__(*args, tunnel=True, melee=True, **kwargs)

class MiniHackWeaponChoiceTunnelMazeMelee(MiniHackWeaponChoice):
    """Preset fixing ``tunnel=True``, ``melee=True`` and ``n_wall=8``."""

    def __init__(self, *args, **kwargs):
        # Passing any fixed option again through kwargs raises a TypeError.
        super().__init__(*args, tunnel=True, melee=True, n_wall=8, **kwargs)

# Resolve the dotted import path of this module for gym entry points.
# Fall back to __name__ so module_path is always defined, even when the
# loader exposes neither attribute (or is None).
if hasattr(__loader__, 'name'):
    module_path = __loader__.name
elif hasattr(__loader__, 'fullname'):
    module_path = __loader__.fullname
else:
    module_path = __name__


# (gym id, class name) for every environment variant defined above.
_REGISTERED_ENVS = (
    ('MiniHack-WeaponChoice-v0', 'MiniHackWeaponChoice'),
    ('MiniHack-WeaponChoiceTunnel-v0', 'MiniHackWeaponChoiceTunnel'),
    ('MiniHack-WeaponChoiceMelee-v0', 'MiniHackWeaponChoiceMelee'),
    ('MiniHack-WeaponChoiceTunnelMelee-v0', 'MiniHackWeaponChoiceTunnelMelee'),
    ('MiniHack-WeaponChoiceTunnelMazeMelee-v0',
     'MiniHackWeaponChoiceTunnelMazeMelee'),
)

for _env_id, _class_name in _REGISTERED_ENVS:
    registration.register(
        id=_env_id,
        entry_point=module_path + ':' + _class_name,
        max_episode_steps=250
    )