import hashlib
import random
import re
import warnings
from collections import Counter
from functools import lru_cache

import gym
import numpy as np
import torch
from minihack.agent.common.envs.wrapper import CachedEnvWrapper
from minihack.base import MH_FULL_ACTIONS
from nle import nethack
from tokenizers import BertWordPieceTokenizer

# Matches a double-quoted span and captures the text between the quotes
# (the captured text is non-empty and contains no double quote).
TEMPLATE_RE = re.compile(r"\"([^\"]+)\"")


# Descriptor tokens (materials, colours, shapes, ...) that
# WordWrapper.clean_message collapses into the single placeholder b"M".
# Every entry is listed exactly once; the original literal repeated a few
# items, which a set silently ignores.
MATERIALS = {
    b"bubbly",
    b"brilliant",
    b"jungle",
    b"yellow",
    b"orange",
    b"puce",
    b"smoky",
    b"magenta",
    b"granite",
    b"opal",
    b"onyx",
    b"wooden",
    b"engagement",
    b"ruby",
    b"balsa",
    b"jade",
    b"shiny",
    b"purple-red",
    b"sky",
    b"spiked",
    b"tin",
    b"runed",
    b"glass",
    b"aluminum",
    b"iron",
    b"short",
    b"silver",
    b"uranium",
    b"ebony",
    b"marble",
    b"iridium",
    b"zinc",
    b"long",
    b"hexagonal",
    b"steel",
    b"oak",
    b"copper",
    b"jeweled",
    b"curved",
    b"forked",
    b"maple",
    b"brass",
    b"platinum",
    b"pine",
    b"crystal",
    b"agate",
    b"black",
    b"white",
    b"red",
    b"green",
    b"blue",
    b"gold",
    b"purple",
    b"dark",
    b"riding",
    b"cloudy",
    b"brown",
    b"mud",
    b"golden",
    b"swirly",
    b"hiking",
    b"diamond",
    b"snow",
    b"buckled",
    b"clay",
    b"combat",
    b"bronze",
    b"cyan",
    b"pink",
    b"milky",
    b"effervescent",
    b"wire",
    b"topaz",
    b"tiger",
    b"eye",
    b"fizzy",
    b"sapphire",
    b"coral",
    b"emerald",
    b"twisted",
    b"moonstone",
    b"ivory",
    b"murky",
    b"pearl",
}

# Monster-name tokens; WordWrapper.clean_message replaces each of them with
# the placeholder b"O".
MONSTERS = {
    b"bug",
    b"fox",
    b"goblin",
    b"grid",
    b"jackal",
    b"kobold",
    b"lichen",
    b"newt",
    b"rat",
    b"sewer",
    b"zombie",
}


# The same tokens with a trailing "!" or "." still attached, as they appear
# when a message is split on single spaces.
MONSTERS_E = {name + b"!" for name in MONSTERS}
MONSTERS_P = {name + b"." for name in MONSTERS}


# Every member of nethack.CompassDirection, in enum definition order.
MOVE_ACTIONS = tuple(nethack.CompassDirection)

# Movement plus OPEN / KICK / SEARCH.
NAVIGATE_ACTIONS = (
    *MOVE_ACTIONS,
    nethack.Command.OPEN,
    nethack.Command.KICK,
    nethack.Command.SEARCH,
)

# Movement plus ZAP / FIRE.
COMBAT_ACTIONS = (
    *MOVE_ACTIONS,
    nethack.Command.ZAP,
    nethack.Command.FIRE,
)

# The combat set followed by PICKUP.
COMBATPICKUP_ACTIONS = (*COMBAT_ACTIONS, nethack.Command.PICKUP)

# Deliberately tiny set: a single move direction plus ZAP / FIRE.
WOD_ACTIONS = (
    nethack.CompassDirection.E,
    nethack.Command.ZAP,
    nethack.Command.FIRE,
)

# Movement plus the item-related commands listed below.
LAVACROSS_ACTIONS = (
    *MOVE_ACTIONS,
    nethack.Command.APPLY,
    nethack.Command.WEAR,
    nethack.Command.QUAFF,
    nethack.Command.FIRE,
    nethack.Command.ZAP,
    nethack.Command.PUTON,
    nethack.Command.PICKUP,
)

# The combat-and-pickup set followed by RUSH.
QUEST_ACTIONS = (*COMBATPICKUP_ACTIONS, nethack.Command.RUSH)

# Lookup used by _create_env: value of FLAGS.minihack.action_space -> actions.
ACTION_SPACES = {
    "full": MH_FULL_ACTIONS,
    "move": MOVE_ACTIONS,
    "navigate": NAVIGATE_ACTIONS,
    "combat": COMBAT_ACTIONS,
    "combatpickup": COMBATPICKUP_ACTIONS,
    "wod": WOD_ACTIONS,
    "lavacross": LAVACROSS_ACTIONS,
    "quest": QUEST_ACTIONS,
}


def create_env(FLAGS):
    """Build the environment selected by ``FLAGS.env``.

    When the environment name contains ``"MultiRoom"``, three environments are
    built with ``_create_env`` and returned inside a ``CachedEnvWrapper``.
    Any other environment is built once and returned directly.
    """
    if "MultiRoom" not in FLAGS.env:
        return _create_env(FLAGS)

    # Make a cached env wrapper
    pool_size = 3
    pool = [_create_env(FLAGS) for _ in range(pool_size)]
    return CachedEnvWrapper(pool, num_threads=pool_size)


def _create_env(FLAGS):
    """Build a single environment from ``FLAGS`` and wrap it.

    Names starting with ``"MiniHack"`` are created with the step penalty and
    win/lose rewards from ``FLAGS.minihack``, plus an explicit action set
    unless ``FLAGS.minihack.action_space`` is ``"default"``. Every other name
    is treated as a NetHack task; win/lose rewards are only forwarded for the
    ``staircase``, ``pet`` and ``oracle`` tasks, otherwise a warning is issued.

    The result is wrapped in up to two ``CounterWrapper`` layers (depending on
    ``FLAGS.state_counter`` and ``FLAGS.separate_message_novelty``) and always
    in a ``WordWrapper`` as the outermost layer.
    """
    env_name = FLAGS.env
    if env_name.startswith("MiniHack"):
        make_kwargs = dict(
            penalty_step=FLAGS.minihack.penalty_step,
            reward_win=FLAGS.minihack.reward_win,
            reward_lose=FLAGS.minihack.reward_lose,
        )
        action_space = FLAGS.minihack.action_space
        if action_space != "default":
            make_kwargs["actions"] = ACTION_SPACES[action_space]
        max_steps = FLAGS.minihack.max_episode_steps
        if max_steps is not None:
            make_kwargs["max_episode_steps"] = max_steps
    else:
        # NetHack
        make_kwargs = {
            "character": FLAGS.minihack.character,
            "savedir": None,
            "penalty_step": FLAGS.minihack.penalty_step,
            "penalty_mode": "constant",
        }
        max_steps = FLAGS.minihack.max_episode_steps
        if max_steps is not None:
            make_kwargs["max_episode_steps"] = max_steps
        if env_name not in ("staircase", "pet", "oracle"):
            # print warning once
            warnings.warn("Ignoring flags.reward_win and flags.reward_lose")
        else:
            make_kwargs["reward_win"] = FLAGS.minihack.reward_win
            make_kwargs["reward_lose"] = FLAGS.minihack.reward_lose

    env = gym.make(env_name, **make_kwargs)

    # Optional visit counters, innermost first.
    counter_kind = FLAGS.state_counter
    if counter_kind != "none":
        env = CounterWrapper(env, counter_kind)
    if FLAGS.separate_message_novelty:
        env = CounterWrapper(
            env, FLAGS.separate_message_state_counter, key="state_visits_m"
        )

    # Message tokenisation is always the outermost wrapper.
    return WordWrapper(
        env,
        FLAGS.minihack.msg.word.vocab_file,
        max_message_len=FLAGS.minihack.msg.word.max_message_len,
        template=FLAGS.minihack.msg.template,
        hash_messages=FLAGS.onehot_language_goals,
    )


# Zero-padded uint8 buffer holding the ASCII text "descend". WordWrapper.step
# substitutes it for the raw message when an episode ends with positive reward.
# The length 256 presumably matches the env's "message" observation -- confirm.
# Filled with one slice assignment so no loop variables leak into the module
# namespace.
DESCEND_MESSAGE = np.zeros((256,), dtype=np.uint8)
DESCEND_MESSAGE[: len("descend")] = np.frombuffer(b"descend", dtype=np.uint8)


class WordWrapper(gym.Wrapper):
    """Gym wrapper that replaces the raw ``message`` observation by token ids.

    After every ``reset``/``step`` the observation dict is updated in place:

    * ``message``: int64 array of length ``max_message_len`` holding the
      word-piece ids of the whole message, zero padded.
    * ``message_len``: int64 array of shape ``(1,)`` with the number of ids kept.
    * ``split_messages``: int64 array of shape
      ``(max_messages, max_message_len)``; one row per sub-message obtained by
      splitting the raw bytes on a double space. Unused rows hold the empty
      message ``[CLS] [SEP]`` (ids 101 and 102).
    * ``split_messages_len``: int64 array of shape ``(max_messages, 1)``.

    Args:
        env: environment to wrap; its observations must contain ``"message"``.
        vocab_file: vocabulary file handed to ``BertWordPieceTokenizer``.
        max_message_len: number of token ids kept per (sub-)message.
        max_messages: number of sub-message rows in ``split_messages``.
        max_vocab_size: stored on the instance; not used by this class itself.
        template: if true, run ``clean_message`` before tokenising.
        remove_brackets: if true, strip text matched by ``brackets_regex``.
        hash_messages: if true, tokenise a short SHA-1 digest of the cleaned
            message instead of the message text.
    """

    def __init__(
        self,
        env,
        vocab_file,
        max_message_len=25,
        max_messages=5,
        max_vocab_size=30522,
        template=True,
        remove_brackets=True,
        hash_messages=False,
    ):
        super().__init__(env)
        self.max_message_len = max_message_len
        self.max_messages = max_messages
        self.max_vocab_size = max_vocab_size
        self.template = template
        self.remove_brackets = remove_brackets
        self.hash_messages = hash_messages

        # NOTE(review): the character class excludes parentheses rather than
        # "]", so the match runs greedily to the last "]" and fails when the
        # bracketed text contains "(" or ")". Left unchanged because altering
        # it would change the token ids produced for existing messages.
        self.brackets_regex = re.compile(r"\[[^()]*\]")
        self.tokenizer = BertWordPieceTokenizer(vocab_file, lowercase=True)

        self.hash = random.randrange(10000000)  # In case lru_cache checks hash

    def __getattr__(self, name):
        """Forward unknown attributes (underscore-prefixed ones too) to ``env``."""
        # __getattr__ only runs when normal lookup fails. If "env" itself is
        # missing (e.g. on an instance rebuilt by copy/pickle before its state
        # is restored), delegating to self.env would recurse without end.
        if name == "env":
            raise AttributeError(name)
        return getattr(self.env, name)

    def step(self, action):
        """Step the wrapped env and tokenise the resulting message.

        When the episode ends with a positive reward, the raw message is
        replaced by ``DESCEND_MESSAGE`` before tokenisation.
        """
        obs, reward, done, info = self.env.step(action)
        if done and reward > 0:
            obs["message"] = DESCEND_MESSAGE
        self._tokenize_message(obs)
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env (forwarding any keyword arguments) and tokenise."""
        obs = self.env.reset(**kwargs)
        self._tokenize_message(obs)
        return obs

    def _tokenize_message(self, obs):
        """Replace ``obs["message"]`` and add the derived token entries in place."""
        # The raw message is a zero-padded byte array; drop the padding.
        message_bytes = obs["message"].tobytes().rstrip(b"\x00")
        (
            obs["message"],
            obs["message_len"],
            obs["split_messages"],
            obs["split_messages_len"],
        ) = self.tokenize_messages(message_bytes)

    def tokenize_messages(self, message_bytes):
        """Tokenise a message both as a whole and split into sub-messages.

        Returns:
            ``(tokens, tokens_len, tokens_split, tokens_len_split)`` with the
            shapes described in the class docstring.
        """
        tokens, tokens_len = self.tokenize_message(message_bytes)

        # Sub-messages are separated by a double space; keep the first few.
        parts = message_bytes.split(b"  ")[: self.max_messages]

        tokens_split = np.zeros(
            (self.max_messages, self.max_message_len), dtype=np.int64
        )
        tokens_len_split = np.zeros((self.max_messages, 1), dtype=np.int64)
        for row, part in enumerate(parts):
            tokens_split[row], tokens_len_split[row] = self.tokenize_message(part)

        # Empty messages: [CLS SEP]. Indexing by len(parts) instead of the loop
        # variable keeps this well defined even if no row was filled.
        first_empty = len(parts)
        tokens_split[first_empty:, 0] = 101  # [CLS]
        tokens_split[first_empty:, 1] = 102  # [SEP]
        tokens_len_split[first_empty:, 0] = 2  # [CLS] [SEP]

        # tokenize_message is memoised and returns the cached arrays
        # themselves; hand out copies so that in-place edits by a caller cannot
        # corrupt later cache hits.
        return tokens.copy(), tokens_len.copy(), tokens_split, tokens_len_split

    @lru_cache(maxsize=10000)
    def tokenize_message(self, message):
        """Return ``(tokens, tokens_len)`` for one message given as ``bytes``.

        ``tokens`` is an int64 array of length ``max_message_len`` (zero
        padded, truncated with a warning when the encoding is longer) and
        ``tokens_len`` an int64 array of shape ``(1,)``.

        Results are memoised per ``(self, message)``; the returned arrays are
        the cached objects and must be treated as read-only.
        """
        # Some cleaning/templating
        if self.template:
            message = self.clean_message(message)

        message = message.decode("utf-8").lower()
        message = message.replace("  ", " ")
        if self.remove_brackets:
            message = self.brackets_regex.sub("", message)

        if self.hash_messages:
            message = self._hash_message(message)

        ids = self.tokenizer.encode(message).ids
        if len(ids) > self.max_message_len:
            warnings.warn(f"exceeded max message len with message {message}")

        kept = min(self.max_message_len, len(ids))
        tokens = np.zeros(self.max_message_len, dtype=np.int64)
        tokens[:kept] = ids[:kept]
        return tokens, np.array([kept], dtype=np.int64)

    def clean_message(self, sentence):
        """Replace variable parts of a message (``bytes``) by placeholders.

        Working on space-separated tokens: material words become ``M``,
        monster names ``O``, digit-only tokens ``N``; ``an`` -> ``a``,
        ``stones`` -> ``stone``, ``his`` -> ``her``. The sentence is cut (and
        closed with ``"."``) at the first ``labeled`` or ``named`` token. A few
        phrase-level replacements then merge repeated placeholders, e.g.
        ``b"f - an iron wand."`` becomes ``b"f - a M wand."``.
        """
        changed_toks = []
        for tok in sentence.split(b" "):
            if tok in MATERIALS:
                changed_toks.append(b"M")
            elif tok in MONSTERS or tok in MONSTERS_E or tok in MONSTERS_P:
                changed_toks.append(b"O")
            elif tok == b"an":
                changed_toks.append(b"a")
            elif tok == b"stones":
                changed_toks.append(b"stone")
            elif tok == b"his":
                changed_toks.append(b"her")  # Consistent pronouns
            elif tok.isdigit():
                changed_toks.append(b"N")
            elif tok in (b"labeled", b"named"):  # scroll labeled... / dog named...
                # Drop the label/name and end the sentence here. Guarded so a
                # message that *starts* with the keyword does not raise
                # IndexError on the still-empty token list.
                if changed_toks:
                    changed_toks[-1] += b"."
                break
            else:
                changed_toks.append(tok)
        changed_sentence = b" ".join(changed_toks)
        changed_sentence = changed_sentence.replace(b"wand of death", b"M wand")
        changed_sentence = changed_sentence.replace(b"wand of cold", b"M wand")
        changed_sentence = changed_sentence.replace(b"O O", b"O")
        changed_sentence = changed_sentence.replace(b"M M", b"M")
        changed_sentence = changed_sentence.replace(b"M - M", b"M")
        return changed_sentence

    def _hash_message(self, message):
        """Return the first 10 hex digits of the SHA-1 digest of ``message``."""
        # UTF-8 yields the same bytes as ASCII for ASCII text, but does not
        # raise on other characters that survived the UTF-8 decode.
        return hashlib.sha1(message.encode("utf-8")).hexdigest()[:10]

    def __hash__(self):
        return self.hash


class CounterWrapper(gym.Wrapper):
    """Gym wrapper that adds a per-episode state visit count to observations.

    The count is written to ``obs[key]`` as a one-element array. What counts
    as "the same state" is chosen by ``state_counter``:

    * ``"none"``: the wrapper is a pass-through for ``step``.
    * ``"ones"``: every state is treated as unique (count is always 1).
    * ``"coordinates"``: entries 12, 0 and 1 of ``obs["blstats"]`` (named
      ``d``, ``x`` and ``y`` in the code; ``d`` is presumably the dungeon
      depth -- confirm against the env).
    * ``"messages"``: the raw bytes of ``obs["message"]``.
    * ``"coordinates_messages"``: both of the above combined.

    Counts are cleared when an episode ends and on ``reset``.
    """

    def __init__(self, env, state_counter="none", key="state_visits"):
        # initialize state counter
        self.state_counter = state_counter
        self.key = key
        if self.state_counter != "none":
            self.state_count_dict = Counter()
        super().__init__(env)

    def __getattr__(self, name):
        """Forward unknown attributes (underscore-prefixed ones too) to ``env``."""
        # __getattr__ only runs when normal lookup fails. If "env" itself is
        # missing (e.g. on an instance rebuilt by copy/pickle before its state
        # is restored), delegating to self.env would recurse without end.
        if name == "env":
            raise AttributeError(name)
        return getattr(self.env, name)

    @staticmethod
    def _position(obs):
        """Return ``(d, x, y)`` read from ``obs["blstats"]``."""
        features = obs["blstats"]
        return (features[12], features[0], features[1])

    def _count_visit(self, obs):
        """Register ``obs`` and return how often its state was seen this episode."""
        mode = self.state_counter
        if mode == "ones":
            # treat every state as unique
            return 1
        if mode == "coordinates":
            # use the location of the agent in the dungeon to accumulate visits
            state = self._position(obs)
        elif mode == "messages":
            state = obs["message"].tobytes()
        elif mode == "coordinates_messages":
            # Visit consists of location + message, so if message changes, you
            # get reward.
            state = self._position(obs) + (obs["message"].tobytes(),)
        else:
            raise NotImplementedError("state_counter=%s" % self.state_counter)
        self.state_count_dict[state] += 1
        return self.state_count_dict[state]

    def step(self, action):
        """Step the wrapped env and attach the visit count to the observation.

        Raises:
            NotImplementedError: if ``state_counter`` is not a known mode.
        """
        step_return = self.env.step(action)
        if self.state_counter == "none":
            # do nothing
            return step_return

        obs, _reward, done, _info = step_return
        obs[self.key] = np.array([self._count_visit(obs)])

        if done:
            self.state_count_dict.clear()

        # obs was updated in place, so the original return value carries it.
        return step_return

    def reset(self, wizkit_items=None):
        """Reset the wrapped env (forwarding ``wizkit_items``) and the counts."""
        obs = self.env.reset(wizkit_items=wizkit_items)
        if self.state_counter != "none":
            self.state_count_dict.clear()
            # current state counts as one visit
            obs[self.key] = np.array([1])
        return obs


def add_time_batch(f):
    """Decorate ``f`` so every value of the dict it returns gains two leading axes.

    ``f`` must return a dict whose values are NumPy arrays or torch tensors.
    Arrays are converted with ``torch.from_numpy`` (sharing memory with the
    array), then each tensor of shape ``S`` is viewed as ``(1, 1) + S`` --
    going by the function name, singleton time and batch dimensions.

    The wrapper keeps ``f``'s ``__name__``, ``__doc__`` and other metadata.
    """
    # Imported locally: the module itself only imports lru_cache from functools.
    from functools import wraps

    @wraps(f)
    def decorated_f(*args, **kwargs):
        output = f(*args, **kwargs)
        batched = {}
        for k, v in output.items():
            tensor = torch.from_numpy(v) if isinstance(v, np.ndarray) else v
            batched[k] = tensor.view((1, 1) + tensor.shape)
        return batched

    return decorated_f


# Sets of message byte strings, keyed by a short task name. The "M"
# placeholder in some entries has the form WordWrapper.clean_message
# produces for material words.
# NOTE(review): not referenced in the visible part of this file; presumably a
# whitelist of messages to keep for the given task -- confirm against callers.
KEEP_MESSAGES = {
    "wod": {
        b"You see here a M wand.",
        b"What do you want to zap? [f or ?*]",
        b"f - a M wand.",
        b"You kill the minotaur!",
        b"You kill it!",
        b"In what direction?",
        b"Welcome to experience level 2.",
        b"You see here a minotaur corpse.",
        b"g - a minotaur corpse.",
    }
}


class TBWrapper:
    """Time-batch wrapper.

    Wraps a gym environment so that ``reset`` and ``step`` return a single dict
    of tensors, each with two extra leading axes added by ``add_time_batch``.
    Besides the environment's observation entries the dict contains ``done``,
    ``reward``, ``episode_return``, ``intrinsic_episode_step``,
    ``extrinsic_episode_step`` and ``episode_win``.

    Attributes not defined here are looked up on the wrapped environment.
    """

    def __init__(self, gym_env):
        self.gym_env = gym_env
        # Observation-space keys plus the entries WordWrapper adds.
        self.initial_env_keys = list(self.gym_env.observation_space.spaces.keys()) + [
            "split_messages",
            "split_messages_len",
            "message_len",
        ]
        # Per-episode bookkeeping; initialised by reset().
        self.episode_return = None
        self.intrinsic_episode_step = None
        self.extrinsic_episode_step = None
        self.episode_win = None

    def __getattr__(self, item):
        # __getattr__ only runs when normal lookup fails. If "gym_env" itself
        # is missing (e.g. on an instance rebuilt by copy/pickle before its
        # state is restored), delegating would recurse without end.
        if item == "gym_env":
            raise AttributeError(item)
        return getattr(self.gym_env, item)

    def get_initial_env_state(self, env_output):
        """Return the entries of ``env_output`` listed in ``initial_env_keys``."""
        return {k: env_output[k] for k in self.initial_env_keys}

    def to_str(self):
        """Render the env's last observation with ``nethack.tty_render``."""
        obs = self.gym_env.last_observation
        tty_chars = obs[self.gym_env._observation_keys.index("tty_chars")]
        tty_colors = obs[self.gym_env._observation_keys.index("tty_colors")]
        tty_cursor = obs[self.gym_env._observation_keys.index("tty_cursor")]
        return nethack.tty_render(tty_chars, tty_colors, tty_cursor)

    def _attach_episode_stats(self, obs, done, reward):
        """Write ``done``, ``reward`` and the episode counters into ``obs``."""
        obs["done"] = torch.tensor(done)
        obs["reward"] = torch.tensor(reward)
        obs["episode_return"] = torch.tensor(self.episode_return)
        obs["intrinsic_episode_step"] = torch.tensor(
            self.intrinsic_episode_step, dtype=torch.int32
        )
        obs["extrinsic_episode_step"] = torch.tensor(
            self.extrinsic_episode_step, dtype=torch.int32
        )
        obs["episode_win"] = torch.tensor(self.episode_win, dtype=torch.int32)
        return obs

    @add_time_batch
    def reset(self):
        """Reset the env and the episode bookkeeping.

        The returned dict reports ``done=True`` and ``reward=0.0`` for this
        initial observation.
        """
        self.episode_return = 0.0
        self.intrinsic_episode_step = 0
        self.extrinsic_episode_step = 0
        self.episode_win = 0
        obs = self.gym_env.reset()
        return self._attach_episode_stats(obs, True, 0.0)

    @add_time_batch
    def step(self, action):
        """Step the env, update the episode bookkeeping and return the obs dict.

        ``episode_win`` is 1 only on a step that ends the episode with a
        positive reward, and 0 otherwise.
        """
        obs, reward, done, info = self.gym_env.step(action)

        self.intrinsic_episode_step += 1
        self.extrinsic_episode_step += 1
        self.episode_return += reward
        self.episode_win = 1 if (done and reward > 0) else 0

        return self._attach_episode_stats(obs, done, reward)
