from collections import defaultdict

import numpy as np
import torch
from gym.utils import seeding
import networkx as nx

from minihack import LevelGenerator


# Monster names that generic 'm' tiles are cycled through by Grid.compile_des.
EASY_MONSTER_NAMES = [
    # easy
    'jackal',
    'sewer rat',
    'grid bug',
    'kobold zombie',
    'newt',

    # hard
    # 'jabberwock',
]

WEAPON_NAMES = [
    'wand',
]

# Tile codes written to the MiniHack terrain layer.
TERRAIN_CHARS = [
    ' ',  # empty
    '.',  # floor
    '-',  # horizontal wall
    '|',  # vertical wall
    '#',  # corridor
    'L',  # lava
    'I',  # ice
    '±',  # tree
]

# Tile codes for entities placed on top of floor terrain.
NON_TERRAIN_CHARS = [
    '+',  # locked door
    'd',  # closed door
    'm',  # monster
    'i',  # item
]

# Terrain tiles that block movement: they are marked as non-free for random
# placement and removed from the navigation graph when computing metrics.
# Every entry needs its own comma -- a missing one silently concatenates two
# adjacent literals (e.g. 'L' '±' -> 'L±') and drops both from the set.
OBSTACLE_TERRAIN_CHARS = {
    ' ',  # empty
    '-',  # horizontal wall
    '|',  # vertical wall
    'L',  # lava
    '±',  # tree
}

class Grid(object):
    """
    Simple class to wrap the MiniHack level generator.
    This reduces the amount of string manipulation we need to do in adversarial.py

    State kept in sync by ``set_char``:
      * ``_map``: H x W array of single-character tile codes, indexed
        ``[y][x]`` (TERRAIN_CHARS / NON_TERRAIN_CHARS, '<' for the agent
        start and '>' for a goal).
      * ``lvl_gen``: MiniHack ``LevelGenerator`` holding the terrain layer.
      * ``nonterrain_coords``: char -> set of ``(x, y)`` of non-terrain tiles.
      * ``goal_locs`` / ``agent_start_loc``: goal and agent positions.
      * ``free_coords``: H x W uint8 array, 1 where a tile may be chosen by
        ``loc='random_free'`` and 0 where it is occupied.
    """
    def __init__(self,
            width,
            height,
            lit=True,
            fixed_grid_str=None,
            diag_paths=True,
            seed=None):
        """
        Args:
            width: Grid width in tiles (replaced by the layout's width when
                ``fixed_grid_str`` is given).
            height: Grid height in tiles (likewise replaced).
            lit: ``lit`` flag passed to the LevelGenerator built in
                ``compile_des``.
            fixed_grid_str: Optional newline-separated tile layout loaded by
                ``clear``.
            diag_paths: If True, diagonal steps are edges of the navigation
                graph used by ``get_metrics``.
            seed: Seed for ``np_random``.
        """
        self.TERRAIN_CHARS = set(TERRAIN_CHARS)
        self.NON_TERRAIN_CHARS = set(NON_TERRAIN_CHARS)
        self.ALL_CHARS = set(
            TERRAIN_CHARS +
            NON_TERRAIN_CHARS +
            ['>', '<']
        )
        self.lit = lit
        self.diag_paths = diag_paths

        # Custom objects and monsters: char -> info dict, plus the chars.
        self.custom_object_info = {}
        self.custom_object_chars = set()
        self.custom_monster_info = {}
        self.custom_monster_chars = set()

        self.width = width
        self.height = height

        self.fixed_grid_str = fixed_grid_str

        self.seed(seed)
        # Separate RNG used when ``unseeded=True`` so those draws do not
        # advance ``np_random``.
        self.unseeded_np_random, _ = seeding.np_random()

        self.clear()

    def seed(self, seed):
        """(Re)create the seeded RNG used for random placements."""
        self.np_random, _ = seeding.np_random(seed)

    def get_grid_str(self):
        """Return ``str()`` of the tile array."""
        return str(self.map)

    def add_custom_objects(self, char2info):
        """Register custom object characters.

        Args:
            char2info: Mapping char -> info dict. ``compile_des`` reads the
                ``'name'`` and ``'symbol'`` keys and, if present,
                ``'cursestate'``.

        The chars become valid for ``set_char`` and are handled as non-terrain.
        Repeated calls accumulate registrations.
        """
        self.NON_TERRAIN_CHARS.update(char2info)
        self.custom_object_chars.update(char2info)
        self.ALL_CHARS |= self.custom_object_chars
        # Merge rather than replace: the char sets above accumulate, so
        # dropping earlier info would make compile_des fail with KeyError.
        self.custom_object_info = {**self.custom_object_info, **char2info}

    def add_custom_monsters(self, char2info):
        """Register custom monster characters.

        Args:
            char2info: Mapping char -> info dict. ``compile_des`` reads the
                ``'name'`` and ``'symbol'`` keys.

        The chars become valid for ``set_char`` and are handled as non-terrain.
        Repeated calls accumulate registrations.
        """
        self.NON_TERRAIN_CHARS.update(char2info)
        self.custom_monster_chars.update(char2info)
        self.ALL_CHARS |= self.custom_monster_chars
        # Merge rather than replace, for the same reason as custom objects.
        self.custom_monster_info = {**self.custom_monster_info, **char2info}

    def _get_random_loc(self, mask=None, unseeded=False):
        """Return a random ``(x, y)``.

        Uniform over the grid when ``mask`` is None. Otherwise ``mask``
        values are used as sampling weights.
        """
        if mask is not None:
            return self._get_random_loc_from_mask(mask, unseeded)
        rng = self.unseeded_np_random if unseeded else self.np_random
        x = rng.randint(self.width)
        y = rng.randint(self.height)
        return x, y

    def _get_random_loc_free(self, mask=None, unseeded=False):
        """Return a random free ``(x, y)``, optionally restricted by ``mask``."""
        weights = self.free_coords if mask is None else self.free_coords & mask
        return self._get_random_loc_from_mask(weights, unseeded)

    def _get_random_loc_from_mask(self, mask, unseeded=False):
        """Sample ``(x, y)`` with probability proportional to ``mask[y][x]``.

        Raises:
            AssertionError: if every entry of ``mask`` is zero.
        """
        p = mask.flatten().astype(np.float32)
        z = p.sum()

        assert z > 0, 'No more free tiles in unmasked tiles'

        rng = self.unseeded_np_random if unseeded else self.np_random
        idx = rng.choice(range(len(p)), 1, p=p/z)
        # Flattened index -> (x, y) in row-major order.
        x = idx % self.width
        y = idx // self.width

        return int(x[0]), int(y[0])

    def mask_neighbors(self, char=None, loc=(), first=False, moore=True):
        """Return a bool mask that is False on the neighbors of matching tiles.

        A tile matches if it holds ``char`` or its ``(x, y)`` is in ``loc``.
        Matching tiles themselves stay True.

        Args:
            char: Tile code to match (may be None when ``loc`` is given).
            loc: Collection of ``(x, y)`` tuples to match.
            first: Only use the first match in row-major order.
            moore: Mask the 8-neighborhood if True, else the 4-neighborhood.

        Returns:
            H x W bool array (True = allowed).
        """
        assert char is not None or len(loc) > 0, 'Must specify either char or locations to mask.'

        mask = np.ones_like(self._map, dtype=bool)
        char_mask = np.ones_like(self._map, dtype=bool)
        max_y = mask.shape[0] - 1
        max_x = mask.shape[1] - 1

        matches = (
            (x, y)
            for y, row in enumerate(self._map)
            for x, map_char in enumerate(row)
            if map_char == char or (x, y) in loc
        )
        for x, y in matches:
            char_mask[y, x] = False

            if moore:  # Moore neighborhood; numpy clips slice ends at the edge
                mask[max(y - 1, 0):y + 2, max(x - 1, 0):x + 2] = False
            else:  # Von Neumann neighborhood
                # Clamped indices: on an edge the missing neighbor collapses
                # onto the tile itself, which is unmasked again below.
                mask[max(y - 1, 0), x] = False  # top
                mask[y, min(x + 1, max_x)] = False  # right
                mask[min(y + 1, max_y), x] = False  # bottom
                mask[y, max(x - 1, 0)] = False  # left

            if first:
                break

        return mask | ~char_mask  # unmask the matched tiles themselves

    def mask(self, char, first=False):
        """Return a bool mask that is False where the tile equals ``char``.

        If ``first`` is True, only the first match in row-major order is
        masked.
        """
        mask = np.ones_like(self._map, dtype=bool)
        for y, row in enumerate(self._map):
            for x, map_char in enumerate(row):
                if map_char == char:
                    mask[y, x] = False
                    if first:
                        return mask
        return mask

    def mask_rect(self, top_left, bottom_right):
        """Return a bool mask that is False inside an inclusive rectangle.

        Both corners are ``(x, y)`` and are clamped to the map first.

        Raises:
            AssertionError: if, after clamping, top_left > bottom_right.
        """
        tl_x, tl_y = top_left
        br_x, br_y = bottom_right

        max_y, max_x = self._map.shape
        max_y -= 1
        max_x -= 1

        tl_x = min(max(tl_x, 0), max_x)
        tl_y = min(max(tl_y, 0), max_y)

        br_x = min(max(br_x, 0), max_x)
        br_y = min(max(br_y, 0), max_y)

        assert (br_x >= tl_x) and (br_y >= tl_y), 'top_left must be <= bottom_right'

        mask = np.ones_like(self._map, dtype=bool)
        mask[tl_y:br_y + 1, tl_x:br_x + 1] = False
        return mask

    @staticmethod
    def _parse_fixed_grid_rows(fixed_grid_str):
        """Split a layout string into rows exactly as ``clear`` consumes it.

        Trailing newlines are dropped and one leading empty line (e.g. from a
        triple-quoted literal starting with a newline) is skipped. Leading
        spaces are preserved because they are empty tiles.
        """
        rows = fixed_grid_str.rstrip('\n').split('\n')
        if rows and len(rows[0]) == 0:
            rows = rows[1:]
        return rows

    @staticmethod
    def fixed_grid_str_dim(fixed_grid_str):
        """Return ``(height, width)`` of the grid ``clear`` builds from the layout."""
        rows = Grid._parse_fixed_grid_rows(fixed_grid_str)
        width = max((len(r) for r in rows), default=0)
        height = len(rows)

        return height, width

    def clear(self, default_char=None):
        """Reset the grid to empty, or to ``fixed_grid_str`` when one is set.

        Rebuilds ``lvl_gen``, the tile map, the goal/agent/free-tile
        bookkeeping and the navigation graph ``grid_graph`` (nodes are
        ``(x, y)``; diagonal edges are added when ``diag_paths`` is True).

        Args:
            default_char: Value the tile map is filled with when no
                ``fixed_grid_str`` is given.
        """
        # Set width / height to match fixed_grid_str
        rows = None
        if self.fixed_grid_str:
            rows = self._parse_fixed_grid_rows(self.fixed_grid_str)
            self.width = max(len(r) for r in rows)
            self.height = len(rows)

        self.lvl_gen = LevelGenerator(
            map=None,
            w=self.width, h=self.height,
            lit=True,
            flags=("premapped",))

        self.nonterrain_coords = defaultdict(set)

        self._map = np.copy(self.lvl_gen.get_map_array())  # H x W
        self._map.fill(default_char)

        self.agent_start_loc = None
        self.goal_locs = set()
        self.free_coords = np.ones((self._map.shape), dtype=np.uint8)

        self.fenced = False

        self.grid_graph = nx.grid_graph(dim=[self.height, self.width])
        if self.diag_paths:
            diags = [
                ((x, y), (x+1, y+1))
                for x in range(self.width-1)
                for y in range(self.height-1)
            ] + [
                ((x+1, y), (x, y+1))
                for x in range(self.width-1)
                for y in range(self.height-1)
            ]
            self.grid_graph.add_edges_from(diags)

        if self.fixed_grid_str:
            # Start from plain floor so set_char finds nothing stale to clear.
            # Pre-writing entity chars (e.g. 'm') into _map made set_char try
            # to remove them from nonterrain_coords before they were added.
            self._map.fill('.')
            for y in range(self.height):
                row = rows[y]
                for x in range(self.width):
                    # Rows shorter than the widest one are padded with ' '.
                    char = row[x] if x < len(row) else ' '
                    self.set_char(char, loc=(x, y))

        self._dirty_map = False
        self._dirty_des = True

    def set_char(self, char='.', loc='random', mask=None, unseeded=False):
        """Place ``char`` at ``loc`` and update all bookkeeping.

        Args:
            char: Tile code; must be in ``self.ALL_CHARS``.
            loc: ``(x, y)`` tuple, ``'random'`` (any tile) or
                ``'random_free'`` (a tile with ``free_coords`` set).
            mask: Optional H x W weights restricting random placement.
            unseeded: Draw random locations from the unseeded RNG.

        Returns:
            The ``(x, y)`` that was written.
        """
        assert char in self.ALL_CHARS, f'Character {char} is not supported.'

        if loc == 'random':
            x, y = self._get_random_loc(mask=mask, unseeded=unseeded)
        elif loc == 'random_free':
            x, y = self._get_random_loc_free(mask=mask, unseeded=unseeded)
        else:
            assert isinstance(loc, tuple), 'Loc must be an x,y tuple or "random".'
            x, y = loc

            # Negative indices would silently wrap around in numpy.
            assert 0 <= x < self.width, f'x {x} must be in [0, width={self.width}).'
            assert 0 <= y < self.height, f'y {y} must be in [0, height={self.height}).'

        # Clear any non-terrain or obstacle terrain character from x,y
        cur_char = self._map[y][x]
        if cur_char in self.NON_TERRAIN_CHARS:
            self.nonterrain_coords[cur_char].remove((x, y))
            self._map[y][x] = '.'
            self.free_coords[y][x] = 1

        if cur_char in OBSTACLE_TERRAIN_CHARS:
            self._map[y][x] = '.'
            self.free_coords[y][x] = 1

        if (x, y) in self.goal_locs:
            self.goal_locs.remove((x, y))
            self.free_coords[y][x] = 1

        if (x, y) == self.agent_start_loc:
            # NOTE(review): the agent start is only dropped when the tile
            # being replaced is an obstacle; writing over the '<' tile itself
            # keeps agent_start_loc here. Confirm whether `char` was intended.
            if cur_char in OBSTACLE_TERRAIN_CHARS:
                self.agent_start_loc = None
                self.free_coords[y][x] = 1

        # === Add new tiles ===
        if char in self.TERRAIN_CHARS:
            self.lvl_gen.add_terrain(coord=(x, y), flag=char)
            if char in OBSTACLE_TERRAIN_CHARS:
                self.free_coords[y][x] = 0

        elif char in self.NON_TERRAIN_CHARS:
            # Entities stand on floor; they are emitted in compile_des.
            self.lvl_gen.add_terrain(coord=(x, y), flag='.')
            self.nonterrain_coords[char].add((x, y))
            self.free_coords[y][x] = 0

        # Place goal
        elif char == '>':
            self.lvl_gen.add_terrain(coord=(x, y), flag='.')
            self.goal_locs.add((x, y))
            self.free_coords[y][x] = 0

        # Place agent
        elif char == '<':
            self.lvl_gen.add_terrain(coord=(x, y), flag='.')
            self.agent_start_loc = (x, y)
            self.free_coords[y][x] = 0

        self._map[y][x] = char

        # Refresh map
        self._dirty_map = True
        self._dirty_des = True

        return x, y

    def get_char(self, loc):
        """Return the tile code at ``loc = (x, y)``."""
        x, y = loc
        return self._map[y][x]

    def fence(self, char='-'):
        """Write ``char`` on every border tile and mark the grid as fenced."""
        h, w = self._map.shape
        for y in range(h):
            self.set_char(char, loc=(0, y))
            self.set_char(char, loc=(w-1, y))

        for x in range(w):
            self.set_char(char, loc=(x, 0))
            self.set_char(char, loc=(x, h-1))

        self.fenced = True

    def _refresh_map(self):
        """Re-stamp entity, goal and agent chars onto ``_map`` if dirty.

        Not called by the ``map`` property, which relies on ``set_char``
        updating ``_map`` inline.
        """
        if self._dirty_map:
            ## Add monsters etc to the map with a placeholder
            for char, coords in self.nonterrain_coords.items():
                for x, y in coords:
                    self._map[y][x] = char

            for (x, y) in self.goal_locs:
                self._map[y][x] = '>'

            if self.agent_start_loc:
                x, y = self.agent_start_loc
                self._map[y][x] = '<'

            self._dirty_map = False

    @property
    def map(self):
        """The H x W tile array (kept current by ``set_char``)."""
        # Ignore hard refresh in favor of inline _map updates
        return self._map

    def compile_des(self):
        """Rebuild ``lvl_gen`` from the terrain and cache the des string.

        Non-terrain entities, the up staircase (agent start) and the down
        staircases (goals) are added on top of the terrain map. If no agent
        start was placed, ``agent_start_loc`` is set to ``(0, 0)``.
        """
        terrain_map = self.lvl_gen.get_map_str()

        self.lvl_gen = LevelGenerator(
            map=terrain_map,
            w=self.width, h=self.height,
            lit=self.lit,
            flags=("premapped",))

        ## Add footer info for nonterrain chars
        for char, coords in self.nonterrain_coords.items():
            if char == 'm':
                # Generic monsters cycle through EASY_MONSTER_NAMES.
                for i, (x, y) in enumerate(coords):
                    self.lvl_gen.add_monster(
                        name=EASY_MONSTER_NAMES[i % len(EASY_MONSTER_NAMES)],
                        place=(x, y))

            elif char == '+':
                for x, y in coords:
                    self.lvl_gen.add_door(
                        state="locked",
                        place=(x, y))

            elif char == 'd':
                for x, y in coords:
                    self.lvl_gen.add_door(
                        state="closed",
                        place=(x, y))

            elif char in self.custom_object_chars:
                info = self.custom_object_info[char]
                for x, y in coords:
                    self.lvl_gen.add_object(
                        place=(x, y),
                        name=info['name'],
                        symbol=info['symbol'],
                        cursestate=info.get('cursestate')
                    )

            elif char in self.custom_monster_chars:
                info = self.custom_monster_info[char]
                for x, y in coords:
                    self.lvl_gen.add_monster(
                        place=(x, y),
                        name=info['name'],
                        symbol=info['symbol']
                    )

        ## Add the staircases (up == agent start, down == goal)
        if self.agent_start_loc is None:
            self.agent_start_loc = (0, 0)

        x, y = self.agent_start_loc
        self.lvl_gen.add_stair_up((x, y))

        for x, y in self.goal_locs:
            self.lvl_gen.add_stair_down((x, y))

        self._des = self.lvl_gen.get_des()

        self._dirty_des = False

    @property
    def des(self):
        """The des-file string, recompiled if the grid changed since last time."""
        if self._dirty_des:
            self.compile_des()

        return self._des

    def get_map_char(self, x, y):
        """Return the tile code at column ``x``, row ``y``."""
        return self.map[y][x]

    def get_metrics(
            self,
            goal_chars=('>',),
            clutter_chars=('-',),
            aliases=None):
        """Compute reachability, path-length and clutter statistics.

        Obstacle tiles are removed from a copy of ``grid_graph``. Shortest
        path lengths are then measured from ``agent_start_loc`` to every tile
        whose char is in ``goal_chars``. ``agent_start_loc`` must be set.

        Args:
            goal_chars: Tile codes treated as goals.
            clutter_chars: Tile codes to count. When the grid is fenced and
                '-' is counted, the perimeter tiles are subtracted.
            aliases: Optional char -> key renaming for the returned dicts.

        Returns:
            dict with keys:
              'passable': key -> True if any goal of that key is reachable;
              'shortest_path_lengths': key -> mean path length over the
                  reachable goals of that key (0 if none is reachable);
              'clutter_counts': key -> number of tiles with that char.
        """
        if aliases is None:
            aliases = {}
        goal_chars = set(goal_chars)
        clutter_chars = set(clutter_chars)

        clutter_counts = {
            aliases.get(char, char): 0 for char in clutter_chars}

        if self.fenced and '-' in clutter_chars:
            # fence() writes '-' on all 2*(H+W)-4 perimeter tiles.
            clutter_counts[aliases.get('-', '-')] \
                -= (2*np.sum(self.map.shape) - 4)

        agent_x, agent_y = self.agent_start_loc
        source = (agent_x, agent_y)

        # Navigation graph without obstacle tiles.
        G = self.grid_graph.copy()

        map_ = self.map
        for y, row in enumerate(map_):
            for x, char in enumerate(row):
                if char in OBSTACLE_TERRAIN_CHARS:
                    G.remove_node((x, y))
                if char in clutter_chars:
                    clutter_counts[aliases.get(char, char)] += 1

        # One BFS from the agent yields distances to every reachable tile.
        # It is computed lazily so maps without goals never query the graph.
        dist_from_agent = None
        goal_distances = defaultdict(list)
        for y, row in enumerate(map_):
            for x, char in enumerate(row):
                if char not in goal_chars:
                    continue
                if dist_from_agent is None:
                    dist_from_agent = nx.single_source_shortest_path_length(
                        G, source)
                # -1 marks an unreachable goal.
                goal_distances[aliases.get(char, char)].append(
                    dist_from_agent.get((x, y), -1))

        passable = {}
        shortest_path_lengths = {}
        # Keys are already aliased; do not alias them a second time.
        for key, distances in goal_distances.items():
            reachable = [d for d in distances if d >= 0]
            passable[key] = len(reachable) > 0
            shortest_path_lengths[key] = np.mean(reachable) if reachable else 0

        info = {
            'passable': passable,
            'shortest_path_lengths': shortest_path_lengths,
            'clutter_counts': clutter_counts
        }

        return info

    def __str__(self):
        """Rows of the tile map, tiles separated by spaces."""
        return '\n'.join(' '.join(row) for row in self.map)