# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum

import numpy as np

from nle.nethack import MAX_GLYPH, GLYPH_MON_OFF, GLYPH_PET_OFF, GLYPH_INVIS_OFF, GLYPH_DETECT_OFF, GLYPH_BODY_OFF, GLYPH_RIDDEN_OFF, GLYPH_OBJ_OFF, GLYPH_CMAP_OFF, GLYPH_EXPLODE_OFF, GLYPH_ZAP_OFF, GLYPH_SWALLOW_OFF, GLYPH_WARNING_OFF, GLYPH_STATUE_OFF, EXPL_MAX, NUM_ZAP
from nle.nethack import objclass, permonst, NUMMONS

# flake8: noqa: F405

# TODO: import this from NLE again
NUM_OBJECTS = 453
# Number of consecutive EXPLODE glyphs that share one id in id_pairs_table.
MAXEXPCHARS = 9

# Monster flag bits tested by the "editable" and "sacrifice" features.
# NOTE(review): these presumably mirror NetHack's monflag.h (M1_POIS, M1_ACID,
# M2_WERE, M2_HUMAN); the M2_WHERE spelling is kept because callers use it.
M1_POIS = 1 << 28  # 0x10000000
M1_ACID = 1 << 27  # 0x08000000
M2_WHERE = 1 << 2  # 0x00000004
M2_HUMAN = 1 << 3  # 0x00000008

def _get_unique_levels():
    """Collect the distinct values of each per-monster feature.

    Walks every monster index ``0 .. NUMMONS - 1`` through ``permonst``,
    gathers five features per monster and reduces each feature with
    ``np.unique`` (sorted, de-duplicated).

    Returns
    -------
    tuple of ndarray
        ``(levels, difficulties, speeds, editable, sacrifice)``: the distinct
        ``mlevel``, ``difficulty`` and ``mmove`` values, then the distinct 0/1
        values of the "editable" flag (no POIS/ACID bit in ``mflags1`` and no
        WHERE/HUMAN bit in ``mflags2``) and of the "sacrifice" flag (no
        HUMAN/WHERE bit in ``mflags2``).
    """
    editable_mask_1 = M1_POIS | M1_ACID
    editable_mask_2 = M2_WHERE | M2_HUMAN
    sacrifice_mask = M2_HUMAN | M2_WHERE

    # One list per feature, in the same order as the returned tuple.
    columns = ([], [], [], [], [])
    for monster_index in range(NUMMONS):
        monster = permonst(monster_index)
        is_editable = ((monster.mflags1 & editable_mask_1) == 0
                       and (monster.mflags2 & editable_mask_2) == 0)
        row = (
            monster.mlevel,
            monster.difficulty,
            monster.mmove,
            int(is_editable),
            int((monster.mflags2 & sacrifice_mask) == 0),
        )
        for column, value in zip(columns, row):
            column.append(value)

    return tuple(np.unique(column) for column in columns)


# Sorted distinct per-monster feature values, computed once at import time.
# The level/difficulty/speed index helpers look monsters up in these arrays.
(
    UNIQUE_LEVELS,
    UNIQUE_DIFFICULTIES,
    UNIQUE_SPEEDS,
    UNIQUE_EDITABLE,
    UNIQUE_SACRIFICE,
) = _get_unique_levels()


def _get_permonst_level_index(index):
    """Return the position of monster ``index``'s ``mlevel`` in ``UNIQUE_LEVELS``.

    Raises ``IndexError`` if the level does not occur in ``UNIQUE_LEVELS``.
    """
    level = permonst(index).mlevel
    return np.flatnonzero(UNIQUE_LEVELS == level)[0]

def _get_permonst_difficulty_index(index):
    """Return the position of monster ``index``'s ``difficulty`` in ``UNIQUE_DIFFICULTIES``.

    Raises ``IndexError`` if the difficulty does not occur in
    ``UNIQUE_DIFFICULTIES``.
    """
    difficulty = permonst(index).difficulty
    return np.flatnonzero(UNIQUE_DIFFICULTIES == difficulty)[0]

def _get_permonst_speed_index(index):
    """Return the position of monster ``index``'s ``mmove`` in ``UNIQUE_SPEEDS``.

    Raises ``IndexError`` if the speed does not occur in ``UNIQUE_SPEEDS``.
    """
    speed = permonst(index).mmove
    return np.flatnonzero(UNIQUE_SPEEDS == speed)[0]

def _get_permonst_editable_index(index):
    """Return 1 if monster ``index`` has none of the POIS/ACID/WHERE/HUMAN bits, else 0.

    NOTE(review): unlike the level/difficulty/speed helpers, this returns the
    flag value itself rather than its position in ``UNIQUE_EDITABLE``; the
    two coincide whenever ``UNIQUE_EDITABLE`` is ``[0, 1]``.
    """
    monster = permonst(index)
    bad_bits_1 = monster.mflags1 & (M1_POIS | M1_ACID)
    bad_bits_2 = monster.mflags2 & (M2_WHERE | M2_HUMAN)
    return int(bad_bits_1 == 0 and bad_bits_2 == 0)

def _get_permonst_sacrifice_index(index):
    """Return 1 if monster ``index`` has neither the HUMAN nor WHERE bit in ``mflags2``, else 0.

    NOTE(review): returns the flag value, not a position in
    ``UNIQUE_SACRIFICE``; they coincide whenever that array is ``[0, 1]``.
    """
    flags2 = permonst(index).mflags2
    return 0 if flags2 & (M2_HUMAN | M2_WHERE) else 1

@enum.unique
class GlyphGroup(enum.IntEnum):
    """Category of a glyph, one member per contiguous ``GLYPH_*_OFF`` range.

    Member order and values follow the glyph-offset layout of NetHack's
    ``display.h``: MON is the first range and STATUE the last. The values are
    stored in column 1 of the table built by ``id_pairs_table``.

    Attributes
    ----------
    MON : int
        Glyphs in ``[GLYPH_MON_OFF, GLYPH_PET_OFF)`` (monsters).
    PET : int
        Glyphs in ``[GLYPH_PET_OFF, GLYPH_INVIS_OFF)`` (tame monsters).
    INVIS : int
        Glyphs in ``[GLYPH_INVIS_OFF, GLYPH_DETECT_OFF)`` (unseen monster marker).
    DETECT : int
        Glyphs in ``[GLYPH_DETECT_OFF, GLYPH_BODY_OFF)`` (detected monsters).
    BODY : int
        Glyphs in ``[GLYPH_BODY_OFF, GLYPH_RIDDEN_OFF)`` (corpses).
    RIDDEN : int
        Glyphs in ``[GLYPH_RIDDEN_OFF, GLYPH_OBJ_OFF)`` (ridden monsters).
    OBJ : int
        Glyphs in ``[GLYPH_OBJ_OFF, GLYPH_CMAP_OFF)`` (objects).
    CMAP : int
        Glyphs in ``[GLYPH_CMAP_OFF, GLYPH_EXPLODE_OFF)`` (map symbols).
    EXPLODE : int
        Glyphs in ``[GLYPH_EXPLODE_OFF, GLYPH_ZAP_OFF)`` (explosions).
    ZAP : int
        Glyphs in ``[GLYPH_ZAP_OFF, GLYPH_SWALLOW_OFF)`` (zap beams).
    SWALLOW : int
        Glyphs in ``[GLYPH_SWALLOW_OFF, GLYPH_WARNING_OFF)`` (engulfing).
    WARNING : int
        Glyphs in ``[GLYPH_WARNING_OFF, GLYPH_STATUE_OFF)`` (warning levels).
    STATUE : int
        Glyphs in ``[GLYPH_STATUE_OFF, MAX_GLYPH)`` (statues).
    """

    # See display.h in NetHack. Values must stay distinct (enforced by
    # @enum.unique) and in range order; they are written into int16 tables.
    MON = 0
    PET = 1
    INVIS = 2
    DETECT = 3
    BODY = 4
    RIDDEN = 5
    OBJ = 6
    CMAP = 7
    EXPLODE = 8
    ZAP = 9
    SWALLOW = 10
    WARNING = 11
    STATUE = 12


def _permonst_feature_columns(monster_index):
    """Return table columns 2-7 for the monster ``permonst(monster_index)``.

    The tuple is ``(difficulty index, level index, 0, speed index,
    editable flag, sacrifice flag)``. Column 4 (object weight) is 0 because
    it is only filled for object glyphs.
    """
    return (
        _get_permonst_difficulty_index(monster_index),
        _get_permonst_level_index(monster_index),
        0,  # object weight: not applicable to monsters
        _get_permonst_speed_index(monster_index),
        _get_permonst_editable_index(monster_index),
        _get_permonst_sacrifice_index(monster_index),
    )


def id_pairs_table():
    """Return a lookup table mapping each glyph to its NLE id and features.

    Returns
    -------
    ndarray
        ``int16`` array of shape ``(MAX_GLYPH + 1, 8)`` indexed by glyph.
        Columns are:

        0. id: for MON, PET, DETECT, BODY, RIDDEN and STATUE glyphs, the
           glyph minus its group offset. For INVIS, OBJ, CMAP, EXPLODE, ZAP,
           SWALLOW and WARNING glyphs, a running counter continuing after the
           MON ids. EXPLODE and ZAP glyphs are bucketed, and all SWALLOW
           glyphs share one id.
        1. group: the ``GlyphGroup`` value.
        2. monster difficulty: position in ``UNIQUE_DIFFICULTIES``.
        3. monster level: position in ``UNIQUE_LEVELS``.
        4. object weight: ``objclass(...).oc_weight`` (OBJ glyphs only).
        5. monster speed: position in ``UNIQUE_SPEEDS``.
        6. monster editable: 0/1 flag.
        7. monster sacrifice: 0/1 flag.

        Columns 2, 3, 5, 6 and 7 are filled for MON and PET glyphs only and
        are 0 for every other glyph. Row ``MAX_GLYPH`` is never written and
        stays all zeros.
    """
    table = np.zeros([MAX_GLYPH + 1, 8], dtype=np.int16)

    num_nle_ids = 0

    for glyph in range(GLYPH_MON_OFF, GLYPH_PET_OFF):
        monster = glyph - GLYPH_MON_OFF
        table[glyph] = (monster, GlyphGroup.MON,
                        *_permonst_feature_columns(monster))
        num_nle_ids += 1

    # Pets reuse the monster ids instead of consuming new ones.
    for glyph in range(GLYPH_PET_OFF, GLYPH_INVIS_OFF):
        monster = glyph - GLYPH_PET_OFF
        table[glyph] = (monster, GlyphGroup.PET,
                        *_permonst_feature_columns(monster))

    for glyph in range(GLYPH_INVIS_OFF, GLYPH_DETECT_OFF):
        table[glyph] = (num_nle_ids, GlyphGroup.INVIS, 0, 0, 0, 0, 0, 0)
        num_nle_ids += 1

    # NOTE(review): DETECT, BODY, RIDDEN and STATUE ids are glyph-minus-offset
    # like MON/PET, but no monster feature columns are filled for them.
    # Confirm this is intended (e.g. editable/sacrifice for BODY).
    for glyph in range(GLYPH_DETECT_OFF, GLYPH_BODY_OFF):
        table[glyph] = (glyph - GLYPH_DETECT_OFF, GlyphGroup.DETECT, 0, 0, 0, 0, 0, 0)

    for glyph in range(GLYPH_BODY_OFF, GLYPH_RIDDEN_OFF):
        table[glyph] = (glyph - GLYPH_BODY_OFF, GlyphGroup.BODY, 0, 0, 0, 0, 0, 0)

    for glyph in range(GLYPH_RIDDEN_OFF, GLYPH_OBJ_OFF):
        table[glyph] = (glyph - GLYPH_RIDDEN_OFF, GlyphGroup.RIDDEN, 0, 0, 0, 0, 0, 0)

    for glyph in range(GLYPH_OBJ_OFF, GLYPH_CMAP_OFF):
        table[glyph] = (num_nle_ids, GlyphGroup.OBJ, 0, 0,
                        objclass(glyph - GLYPH_OBJ_OFF).oc_weight, 0, 0, 0)
        num_nle_ids += 1

    for glyph in range(GLYPH_CMAP_OFF, GLYPH_EXPLODE_OFF):
        table[glyph] = (num_nle_ids, GlyphGroup.CMAP, 0, 0, 0, 0, 0, 0)
        num_nle_ids += 1

    # Each run of MAXEXPCHARS consecutive explosion glyphs shares one id, and
    # EXPL_MAX ids are reserved. NOTE(review): this assumes
    # GLYPH_ZAP_OFF - GLYPH_EXPLODE_OFF == EXPL_MAX * MAXEXPCHARS.
    for glyph in range(GLYPH_EXPLODE_OFF, GLYPH_ZAP_OFF):
        id_ = num_nle_ids + (glyph - GLYPH_EXPLODE_OFF) // MAXEXPCHARS
        table[glyph] = (id_, GlyphGroup.EXPLODE, 0, 0, 0, 0, 0, 0)

    num_nle_ids += EXPL_MAX

    # Each run of 4 consecutive zap glyphs shares one id, and NUM_ZAP ids are
    # reserved. NOTE(review): this assumes
    # GLYPH_SWALLOW_OFF - GLYPH_ZAP_OFF == 4 * NUM_ZAP.
    for glyph in range(GLYPH_ZAP_OFF, GLYPH_SWALLOW_OFF):
        id_ = num_nle_ids + (glyph - GLYPH_ZAP_OFF) // 4
        table[glyph] = (id_, GlyphGroup.ZAP, 0, 0, 0, 0, 0, 0)

    num_nle_ids += NUM_ZAP

    # All swallow glyphs collapse onto a single id.
    for glyph in range(GLYPH_SWALLOW_OFF, GLYPH_WARNING_OFF):
        table[glyph] = (num_nle_ids, GlyphGroup.SWALLOW, 0, 0, 0, 0, 0, 0)
    num_nle_ids += 1

    for glyph in range(GLYPH_WARNING_OFF, GLYPH_STATUE_OFF):
        table[glyph] = (num_nle_ids, GlyphGroup.WARNING, 0, 0, 0, 0, 0, 0)
        num_nle_ids += 1

    for glyph in range(GLYPH_STATUE_OFF, MAX_GLYPH):
        table[glyph] = (glyph - GLYPH_STATUE_OFF, GlyphGroup.STATUE, 0, 0, 0, 0, 0, 0)

    return table


# # test
# table = id_pairs_table()
# for glyph in range(GLYPH_MON_OFF, GLYPH_PET_OFF):
#     print(f"{permonst(glyph).mname}: {table[glyph][-2]}")
# print(UNIQUE_EDITABLE)
# print(UNIQUE_SACRIFICE)
