import re
from collections import defaultdict

import numpy as np

from gym.utils import seeding
from gym.envs import registration
from minihack import MiniHackNavigation, LevelGenerator, RewardManager
from minihack.tiles.rendering import get_des_file_rendering
from minihack.base import MH_NETHACKOPTIONS

from envs.minihack import minihack_grid 
from nle import nethack
from nle.env.base import FULL_ACTIONS

from util import cprint


# Action set exposed to the agent by default: every member of
# nethack.CompassDirection, in enum order.
_ACTIONS = tuple(nethack.CompassDirection)

# Reverse lookup: action -> position inside _ACTIONS.
_ACTION2INDEX = {action: index for index, action in enumerate(_ACTIONS)}

# Reverse lookup: action -> position inside NLE's FULL_ACTIONS.
_FULL_ACTION2INDEX = {
    action: index for index, action in enumerate(FULL_ACTIONS)
}


# Module-level debug switch (only referenced by commented-out code here).
DEBUG = False


class MiniHackCustomMaze(MiniHackNavigation):
    """Fixed-layout MiniHack navigation task with armor, lava, walls and monsters.

    Every level is drawn onto ``self.grid`` from ``fixed_grid_str`` and then
    populated with lava, extra walls, monsters, the agent start (``<``), the
    goal (``>``) and exactly one armor item sampled from ``item_info``.
    ``step`` automatically picks up and wears the armor when the game reports
    that the agent is standing on it.

    Args:
        fully_observable: Run NetHack in wizard mode. ``reset`` then issues
            wizard commands to reveal the level, and ``die?`` prompts are
            confirmed automatically.
        fixed_grid_str: Layout string for the grid. Defaults to an empty
            13x13 room of ``.`` cells.
        p: Probabilities over the entries of ``item_info`` used to sample
            which armor item is present. Defaults to ``[0.3, 0.35, 0.35]``.
        full_maze: Stored on the instance; not used by this class.
        randomize_monsters: Forwarded as ``unseeded`` when placing monsters.
        reward_kill: Add 1.0 to the reward when a kill message is seen.
        reward_dist: Stored on the instance; not used by this class.
        n_lava, n_wall, n_monster: Number of lava / wall / monster cells to
            place (an upper bound when ``generator_dist == 'uniform'``).
        generator_dist: ``'uniform'`` samples each count from ``1..n``; any
            other value places exactly ``n`` of each.
        max_episode_steps: Forwarded to the parent environment.
        lit: Forwarded to the placeholder ``LevelGenerator``.
        autopickup: Forwarded to the parent; when false, ``!autopickup`` is
            also added to the NetHack options.
        penalty_step: Stored on the instance (not forwarded to the parent).
        seed: Seed for the grid and for ``np_random``.
        obl_correction: Sample the armor item with the unseeded RNG.
        use_learned_beliefs: With ``obl_correction``, sample the armor item
            from ``belief_dist`` instead of ``p``.
        goal_hint_p: Stored on the instance; not used by this class.
        fixed_environment: Accepted for compatibility; not used.
    """
    def __init__(
        self,
        *args,
        fully_observable=False,
        fixed_grid_str=None,
        p=None,
        full_maze=False,
        randomize_monsters=False,
        reward_kill=False,
        reward_dist='uniform',
        n_lava=0,
        n_wall=0,
        n_monster=2,
        generator_dist='constant',
        max_episode_steps=250,
        lit=True,
        autopickup=False,
        penalty_step=0.0,
        seed=None,
        obl_correction=False,
        use_learned_beliefs=False,
        goal_hint_p=0.0,
        fixed_environment=False,
        **kwargs
    ):
        # Avoid a mutable default argument shared between instances.
        if p is None:
            p = [0.3, 0.35, 0.35]

        # Role code "bar" (Barbarian), consistent with the "role:bar" option.
        kwargs["character"] = "bar-hum-law-mal"
        kwargs["max_episode_steps"] = max_episode_steps
        kwargs["autopickup"] = kwargs.pop("autopickup", autopickup)
        kwargs["actions"] = kwargs.pop("actions", _ACTIONS)

        kwargs["allow_all_yn_questions"] = kwargs.pop(
            "allow_all_yn_questions", True
        )

        self.fully_observable = fully_observable
        if fully_observable:
            kwargs["wizard"] = True

        custom_options = ["nudist", "role:bar"]
        if not autopickup:
            custom_options.append('!autopickup')
        kwargs["options"] = tuple(list(MH_NETHACKOPTIONS) + custom_options)

        if not fixed_grid_str:
            # Default layout: an empty 13x13 room.
            fixed_grid_str = """
.............
.............
.............
.............
.............
.............
.............
.............
.............
.............
.............
.............
.............
        """
        # Always store the layout; _init_empty_grid reads it from the instance.
        self.fixed_grid_str = fixed_grid_str

        self.generator_dist = generator_dist
        self.n_lava = n_lava
        self.n_wall = n_wall
        self.n_monster = n_monster

        self.full_maze = full_maze
        self.reward_kill = reward_kill
        self.randomize_monsters = randomize_monsters

        self.p = p
        self.obl_correction = obl_correction
        self.use_learned_beliefs = use_learned_beliefs
        # Grid character -> armor item. Index 0 is treated as the fireproof
        # item by step() and aux_properties.
        self.item_info = {
            '0': {'name':'red dragon scale mail', 'symbol':'['},
            '1': {'name':'leather jacket', 'symbol':'['},
            '2': {'name':'leather armor', 'symbol':'['},
        }

        if self.obl_correction and self.use_learned_beliefs:
            # Start from a uniform belief until set_belief_dist is called.
            self.belief_dist = {
                'item_present_index': np.ones(len(self.item_info))/len(self.item_info)
            }
        else: # assume ground truth beliefs
            belief_dist = np.array(p)
            self.belief_dist = {
                'item_present_index': belief_dist/belief_dist.sum()
            }

        self.goal_hint_p = goal_hint_p
        self.reward_dist = reward_dist
        self.item_present_index = None

        self.last_message = ''
        # Defined here so step() is safe to call before the first reset().
        self.autowear_interrupted = False

        self._init_empty_grid(seed=seed)

        # Placeholder 1x1 level; the real des-file is installed in reset().
        self.lvl_gen = LevelGenerator(w=1, h=1, lit=lit)
        self.reward_manager = RewardManager()

        super().__init__(*args,
            des_file=self.lvl_gen.get_des(), **kwargs)

        self.penalty_step = penalty_step
        # NOTE(review): this runs after super().__init__, so the parent never
        # sees it. Kept as-is to preserve existing reward behavior.
        kwargs["penalty_step"] = penalty_step

        self.seed(seed) # sets np_random seed

        self.unseeded_np_random,_ = seeding.np_random()

        self._regenerate_level()

    @staticmethod
    def _message_text(obs):
        """Return the lower-cased, UTF-8 decoded ``message`` entry of ``obs``."""
        return obs['message'].tobytes().decode('utf-8').lower()

    def _step_full_action(self, action_index):
        """Run one parent ``step`` with ``FULL_ACTIONS`` as the action set.

        ``action_index`` indexes into ``FULL_ACTIONS``. The previous action
        set is always restored, including when the parent step raises.
        """
        saved_actions = self._actions
        self._actions = FULL_ACTIONS
        try:
            return super().step(action_index)
        finally:
            self._actions = saved_actions

    def _reset_episodic_counts(self):
        """Zero the per-episode counters reported in ``info`` at episode end."""
        # 'goal_reached' is not listed here; step() adds it only on success.
        self.episodic_counts = {
            'wear_armor': 0,
            'wear_fireproof': 0,
            'monster_kills': 0,
        }

    def _init_empty_grid(self, seed=None):
        """
        Build ``self.grid`` from ``self.fixed_grid_str`` and register the
        custom armor objects and the custom monster on it.
        """
        h,w = minihack_grid.Grid.fixed_grid_str_dim(self.fixed_grid_str)
        self.grid = minihack_grid.Grid(
            width=w,
            height=h,
            fixed_grid_str=self.fixed_grid_str,
            seed=seed)

        self.grid.add_custom_objects({
            k: {'name':info['name'], 'symbol':info['symbol']} for k, info in self.item_info.items()})

        self.minion_name = 'pyrolisk'
        self.minion_symbol = 'c'

        # Monsters are written onto the grid with the character 'j'.
        self.grid.add_custom_monsters({
            'j': {'name': self.minion_name, 'symbol': self.minion_symbol}
        })

        h,w = self.grid.map.shape

        # NOTE(review): presumably marks the interior inside the outer border;
        # confirm against Grid.mask_rect.
        self.room_mask = self.grid.mask_rect((1,1), (w-2,h-2))

        # === Update rewards ===
        # Should consider supporting a dense reward version of this task
        self.reward_manager = RewardManager()

    def _regenerate_level(self):
        """
        Clear the grid and place a fresh level on it.

        Places, in order: lava, walls, monsters, the agent start ('<'), the
        goal ('>') and one armor item sampled from ``item_info``. Then caches
        the shortest-path metrics reported by the grid.
        """
        self._reset_episodic_counts()

        # === Generate topology ===
        self.grid.clear()

        h,w = self.grid.map.shape

        n_lava = self.n_lava
        n_wall = self.n_wall
        n_monster = self.n_monster
        if self.generator_dist == 'uniform':
            # Sample each count uniformly from 1..n (a count of 0 stays 0).
            if self.n_lava > 0:
                n_lava = self.np_random.randint(self.n_lava) + 1
            if self.n_wall > 0:
                n_wall = self.np_random.randint(self.n_wall) + 1
            if self.n_monster > 0:
                n_monster = self.np_random.randint(self.n_monster) + 1

        # Place lava
        for i in range(n_lava):
            self.grid.set_char('L', loc='random_free', mask=~self.room_mask)

        # Place walls
        for i in range(n_wall):
            self.grid.set_char('-', loc='random_free', mask=~self.room_mask)

        # Place monsters
        for i in range(n_monster):
            self.grid.set_char('j',
                loc='random_free',
                mask=~self.room_mask,
                unseeded=self.randomize_monsters)

        # Place agent
        self.grid.set_char('<', loc='random_free', mask=~self.room_mask)

        # Place goal
        self.grid.set_char('>', loc='random_free', mask=~self.room_mask)

        # Choose which armor item is present in this level.
        item_dist = self.p
        item_options = range(len(item_dist))
        if self.obl_correction:
            if self.use_learned_beliefs:
                item_dist = self.belief_dist['item_present_index']
            # Sampled with the unseeded RNG, independent of the level seed.
            self.item_present_index = self.unseeded_np_random.choice(item_options, p=item_dist)
        else:
            self.item_present_index = self.np_random.choice(item_options, p=item_dist)

        item_char = list(self.item_info)[self.item_present_index]

        # Place the armor item on a random free cell
        self.grid.set_char(item_char, loc='random_free', mask=~self.room_mask)

        # === Compute shortest path metrics ===
        clutter_chars = []
        if self.n_lava > 0:
            clutter_chars.append('L')
        if self.n_wall > 0:
            clutter_chars.append('-')
        if self.n_monster > 0:
            clutter_chars.append('j')
        grid_info = \
            self.grid.get_metrics(
                goal_chars=['>',] + [k for k in self.item_info],
                clutter_chars=clutter_chars,
                aliases={
                    '>': 'goal',
                    '-': 'wall',
                    'L': 'lava',
                    'j': 'monster',
                    **{k:'armor' for k in self.item_info.keys()}
                })
        self.passable = grid_info['passable']
        self.shortest_path_length = grid_info['shortest_path_lengths']
        self.n_clutter_placed = grid_info['clutter_counts']

    def _wizard_accept_death(self, cur_message, r):
        """Confirm a wizard-mode ``die?`` prompt, if one is showing.

        Returns the ``(obs, r, done, info)`` tuple produced by answering
        ``y`` when the environment is fully observable and ``cur_message``
        starts with ``die?``; otherwise returns ``False``.
        """
        if not self.fully_observable:
            return False
        if not re.match(r"die\?", cur_message):
            return False

        obs, _, done, info = self._step_full_action(_FULL_ACTION2INDEX[ord('y')])
        if done:
            info['episodic_counts'] = self.episodic_counts
        return obs, r, done, info

    def step(self, action):
        """Advance one agent step, auto-wearing armor the agent walks onto.

        After the parent step, the game message is inspected:

        * "you see here ... armor": pick the item up and issue ``W`` plus its
          inventory letter. Sending the letter adds 1.0 to the reward and
          updates the wear counters.
        * otherwise, if an earlier auto-wear attempt did not get as far as
          sending ``W``: retry ``W`` and answer the prompt.
        * otherwise, a kill message for the custom monster increments the
          kill counter (and the reward if ``reward_kill`` is set).

        Returns:
            The usual ``(obs, reward, done, info)`` tuple. When ``done``,
            ``info['episodic_counts']`` holds the episode counters.
        """
        obs, r, done, info = super().step(action)

        cur_message = self._message_text(obs)
        died = self._wizard_accept_death(cur_message, r)
        if died: return died

        match_armor_loc_re = r"you see here .* (scale mail|jacket|armor)"

        match_armor_loc = re.match(match_armor_loc_re, cur_message)
        if match_armor_loc:
            self.autowear_interrupted = True
            obs, _, done, info = self._step_full_action(
                _FULL_ACTION2INDEX[nethack.Command.PICKUP])
            cur_message = self._message_text(obs)
            died = self._wizard_accept_death(cur_message, r)
            if died: return died

            # Pickup message looks like "<letter> - ... armor"; the letter is
            # the inventory slot to pass to the wear command.
            match_armor_pickup = re.match(r"^([a-z]) - .* (scale mail|jacket|armor)", cur_message)
            if match_armor_pickup:
                self.armor_inventory_index = match_armor_pickup[1]
                wear_actions = f'W{self.armor_inventory_index}'
                for a in wear_actions:
                    if done: break
                    obs, _, done, info = self._step_full_action(
                        _FULL_ACTION2INDEX[ord(a)])
                    self.autowear_interrupted = False

                    cur_message = self._message_text(obs)
                    died = self._wizard_accept_death(cur_message, r)
                    if died: return died

                    # Track successful wear
                    if a == self.armor_inventory_index:
                        r += 1.0
                        self.episodic_counts['wear_armor'] = 1

                        if self.item_present_index == 0:
                            self.episodic_counts['wear_fireproof'] = 1

        elif not done and self.autowear_interrupted:
            # NOTE(review): this branch never clears autowear_interrupted, so
            # the retry repeats on every later step, and the kill check below
            # is skipped meanwhile. Confirm this is intended.
            obs, _, done, info = self._step_full_action(_FULL_ACTION2INDEX[ord('W')])
            cur_message = self._message_text(obs)
            died = self._wizard_accept_death(cur_message, r)
            if died: return died

            match_armor = re.match(r"what do you want to wear\? \[(.+) or \?\*\]", cur_message)
            if not done and match_armor:
                self.armor_inventory_index = match_armor[1]
                # Several candidates listed: take the first one.
                if len(self.armor_inventory_index) > 1:
                    self.armor_inventory_index = self.armor_inventory_index[0]
                obs, _, done, info = self._step_full_action(_FULL_ACTION2INDEX[
                    ord(self.armor_inventory_index)
                ])
                cur_message = self._message_text(obs)
                died = self._wizard_accept_death(cur_message, r)
                if died: return died

        elif f'you kill the {self.minion_name}!' in cur_message:
            if self.reward_kill:
                r += 1.0

            self.episodic_counts['monster_kills'] += 1

        self.last_message = cur_message

        if done:
            end_status = info.get('end_status', None)
            if end_status is not None and end_status.name == 'TASK_SUCCESSFUL':
                self.episodic_counts['goal_reached'] = 1

            info['episodic_counts'] = self.episodic_counts

        return obs, r, done, info

    def reset(self):
        """Generate a new level, install it and reset the underlying game.

        In fully observable mode the reset is followed by a scripted
        sequence of wizard-mode commands and a single WAIT step, whose
        observation is returned.
        """
        self.autowear_interrupted = False

        self._regenerate_level()
        super().update(des_file=self.grid.des)

        if not self.fully_observable:
            return super().reset()

        super().reset()

        # Keystrokes are sent straight to the raw env, bypassing the
        # restricted action set.
        for c in "#wizintrinsic\rt\r\r#wizmap\r#wizwish\ra potion of object detection\r":
            obs, _ = self.env.step(ord(c))
        msg = (
            obs[self._original_observation_keys.index("message")]
            .tobytes()
            .decode('utf-8')
        )

        # NOTE(review): msg[0] is presumably the inventory letter of the
        # wished-for potion; 'q' followed by that letter then quaffs it.
        for c in f"q{msg[0]}":
            obs, _ = self.env.step(ord(c))

        # One WAIT through the regular step path yields a processed obs.
        obs, _, _, _ = self._step_full_action(
            FULL_ACTIONS.index(nethack.MiscDirection.WAIT))
        return obs

    def reset_agent(self):
        """Re-install the current grid's des-file and reset, without regenerating."""
        super().update(des_file=self.grid.des)
        return super().reset()

    def seed(self, level_seed, *args, **kwargs):
        """Seed the parent env, ``np_random`` and the grid with ``level_seed``."""
        super().seed(level_seed, *args, **kwargs)

        self._level_seed = level_seed

        self.np_random, _ = seeding.np_random(level_seed)
        self.grid.seed(level_seed)

    @property
    def aux_properties(self):
        """Whether the fireproof item (index 0) is present, and the belief in it."""
        return {
            'fireproof_armor': self.item_present_index == 0,
            'belief_fireproof': self.belief_dist['item_present_index'][0]
        }

    @property
    def belief_spec(self):
        """Describe the belief variables: one categorical over ``item_info``."""
        return {
            'item_present_index': {
                'type': 'categorical',
                'size': len(self.item_info)
            }
        }

    @property
    def belief_tokens(self):
        """One-hot item index freshly sampled from ``p``.

        Each access draws from ``self.np_random`` and advances its state.
        """
        num_item_options = len(self.item_info)
        item_one_hot = np.zeros(num_item_options, dtype=np.int32)
        sample_index = self.np_random.choice(range(num_item_options), p=self.p)
        item_one_hot[sample_index] = 1

        tokens = {
            'item_present_index': item_one_hot
        }
        return tokens

    def set_belief_dist(self, belief_dist):
        """Replace the belief distribution dict used for sampling and reporting."""
        self.belief_dist = belief_dist

    @property
    def des_file(self):
        """The des-file text of the current grid."""
        return self.grid.des

    @property
    def grid_str(self):
        """String form of the current grid map."""
        return self.grid.map.__str__()

    def render(self, mode='level'):
        """Render the level from its des-file for ``mode='level'``; else defer to the parent."""
        if mode == 'level':
            des_file = self.grid.des
            return np.asarray(get_des_file_rendering(des_file, wizard=True))
        else:
            return super().render(mode=mode)


class MiniHackSixteenRooms(MiniHackCustomMaze):
    """Custom maze on a 13x13 layout split into rooms by '-' wall segments."""

    def __init__(self, *args, **kwargs):
        # '.' cells are open floor, '-' cells are walls.
        layout = """
...-..-..-...
.........-...
...-..-......
-.---.--.---.
...-.........
......-..-...
--.-.--.---.-
...-.....-...
...-..-......
.----.--.-.--
...-..-..-...
......-......
...-.....-...
        """

        super().__init__(*args, fixed_grid_str=layout, **kwargs)


class MiniHackLabyrinth(MiniHackCustomMaze):
    """Custom maze on a 13x13 labyrinth layout of nested '-' wall corridors."""

    def __init__(self, *args, **kwargs):
        # '.' cells are open floor, '-' cells are walls.
        layout = """
.............
.-----------.
.-.........-.
.-.-------.-.
.-.-.....-.-.
.-.-.---.-.-.
.-.-.-.-.-.-.
.-.-.-.-.-.-.
.-...-...-.-.
.---------.-.
.....-.....-.
----.-.-----.
.....-.......
        """

        super().__init__(*args, fixed_grid_str=layout, **kwargs)


class MiniHackMaze(MiniHackCustomMaze):
    """Custom maze on a fixed 13x13 maze layout built from '-' wall cells."""

    def __init__(self, *args, **kwargs):
        # '.' cells are open floor, '-' cells are walls.
        layout = """
.....-....-..
.---.----.--.
.-...........
.--------.---
........-....
------.-----.
....-..-.....
.--...--.----
..-.-..-...-.
-.-.--.---.-.
-.-..-...-...
-.--.---.---.
...-...-.-...
        """

        super().__init__(*args, fixed_grid_str=layout, **kwargs)


# Resolve this module's dotted import path for the gym entry points below.
if hasattr(__loader__, 'name'):
    module_path = __loader__.name
elif hasattr(__loader__, 'fullname'):
    module_path = __loader__.fullname
else:
    # The loader exposes neither attribute (or is None). Fall back to the
    # module's own name so the registrations below do not hit a NameError.
    module_path = __name__


# Register the fixed-layout variants with gym. Importing this module is what
# makes the environment ids available.
registration.register(
    id='MiniHack-SixteenRooms-v0',
    entry_point=module_path + ':MiniHackSixteenRooms',
    max_episode_steps=250
)

registration.register(
    id='MiniHack-Labyrinth-v0',
    entry_point=module_path + ':MiniHackLabyrinth',
    max_episode_steps=250
)

registration.register(
    id='MiniHack-Maze-v0',
    entry_point=module_path + ':MiniHackMaze',
    max_episode_steps=250
)