from collections import Counter

import gym
import gym_minigrid.wrappers as wrappers
import numpy as np
import torch
from babyai.levels.verifier import INSTRS
from gym_minigrid import minigrid
from gym_minigrid.minigrid import COLOR_TO_IDX, OBJECT_TO_IDX


def relative_coords_from_grid(abs_goal_x, abs_goal_y, grid, partial_grid):
    """Convert an absolute grid cell into coordinates of the agent's partial view.

    The agent is located as the single cell of ``grid`` whose channel-0 value
    is 10; its direction is channel 2 of that cell. The agent-to-goal offset
    is rotated ``agent_dir + 1`` times by the quarter turn
    ``(dx, dy) -> (dy, -dx)`` (e.g. ``(1, 0) -> (0, -1)``) and then placed
    relative to the agent's fixed spot in the egocentric view: middle column,
    last row.

    Args:
        abs_goal_x, abs_goal_y: absolute goal cell in ``grid``.
        grid: full grid, indexed ``grid[x, y, channel]``; expected to be a
            torch tensor (``.nonzero()`` must return an (n, 2) index tensor).
        partial_grid: egocentric view; only ``shape[0]`` (view size) is used.

    Returns:
        ``(x, y)`` in the partial view, or None when the goal lies outside it.
    """
    assert grid.ndim == 3
    assert partial_grid.ndim == 3

    # Exactly one cell must carry the agent marker.
    hits = (grid[:, :, 0] == 10).nonzero()
    assert hits.shape == (1, 2)  # Should be one coord
    ax, ay = hits[0, 0].item(), hits[0, 1].item()
    facing = grid[ax, ay][2].item()

    off_x = abs_goal_x - ax
    off_y = abs_goal_y - ay
    turns = facing + 1
    while turns > 0:
        off_x, off_y = off_y, -off_x
        turns -= 1

    view = partial_grid.shape[0]
    rel_x = off_x + view // 2
    rel_y = off_y + (view - 1)

    inside = 0 <= rel_x < view and 0 <= rel_y < view
    return (rel_x, rel_y) if inside else None


class CounterWrapper(gym.Wrapper):
    """Annotate every observation with a per-episode state-visit count.

    The count for the current state is stored in ``obs[key]`` as a
    one-element numpy array. ``state_counter`` selects what identifies a
    state:

    * ``"none"``: every state is treated as unique, so the count is always 1.
    * ``"coordinates"``: the flattened ``obs["image"]``.
    * ``"messages"``: the flattened ``obs["subgoal_done"]``.
    * ``"coordinates_messages"``: both of the above, concatenated.

    Any other value raises ``NotImplementedError`` on the first ``reset`` or
    ``step``. Counts are cleared when an episode ends and on ``reset``; the
    observation returned by ``reset`` is recorded as the first visit.
    """

    def __init__(self, env, state_counter="none", key="state_visits"):
        self.state_counter = state_counter
        self.key = key
        # Created unconditionally (also for "none") so that clearing it never
        # falls through __getattr__ to the wrapped env -- which would either
        # raise AttributeError or, when wrapping another CounterWrapper, clear
        # that wrapper's counter instead of this one.
        self.state_count_dict = Counter()
        super().__init__(env)

    def __getattr__(self, name):
        # Forward unknown attributes to the wrapped env. "env" itself is
        # guarded so that lookups before __init__ has run (e.g. during copy or
        # unpickling) raise AttributeError instead of recursing forever.
        if name == "env":
            raise AttributeError(name)
        return getattr(self.env, name)

    def _state_key(self, obs):
        """Return the hashable key identifying ``obs``'s state (None for "none")."""
        if self.state_counter == "none":
            return None
        if self.state_counter == "coordinates":
            # use BabyAI frames
            return tuple(obs["image"].ravel().tolist())
        if self.state_counter == "messages":
            return tuple(obs["subgoal_done"].ravel().tolist())
        if self.state_counter == "coordinates_messages":
            return (
                *obs["image"].ravel().tolist(),
                *obs["subgoal_done"].ravel().tolist(),
            )
        raise NotImplementedError("state_counter=%s" % self.state_counter)

    def _record_visit(self, obs):
        """Count a visit to ``obs``'s state and store the count in ``obs[self.key]``."""
        state_key = self._state_key(obs)
        if state_key is None:
            # treat every state as unique
            state_visits = 1
        else:
            self.state_count_dict[state_key] += 1
            state_visits = self.state_count_dict[state_key]
        obs[self.key] = np.array([state_visits])

    def step(self, action):
        """Step the wrapped env and add the new state's visit count to obs.

        Returns the wrapped env's step tuple unchanged apart from the in-place
        addition of ``obs[self.key]``.
        """
        step_return = self.env.step(action)
        obs, _, done, _ = step_return
        self._record_visit(obs)
        if done:
            # Counts are per episode.
            self.state_count_dict.clear()
        return step_return

    def reset(self, **kwargs):
        """Reset the wrapped env and the counter; the initial state counts as one visit."""
        obs = self.env.reset(**kwargs)
        self.state_count_dict.clear()
        self._record_visit(obs)
        return obs


def create_env(FLAGS):
    """Build the training environment described by ``FLAGS``.

    Wrapping order, innermost first:
    ``gym.make(FLAGS.env)`` -> ``FullyObsWrapperWithSubgoals`` ->
    ``Minigrid2ImageWithSubgoals`` -> ``CounterWrapper(FLAGS.state_counter)``.
    When ``FLAGS.separate_message_novelty`` is truthy, one more
    ``CounterWrapper`` driven by ``FLAGS.separate_message_state_counter`` is
    added on top and writes its count under ``"state_visits_m"``.
    """
    base = gym.make(FLAGS.env)
    wrapped = Minigrid2ImageWithSubgoals(FullyObsWrapperWithSubgoals(base))
    wrapped = CounterWrapper(wrapped, FLAGS.state_counter)
    if not FLAGS.separate_message_novelty:
        return wrapped
    return CounterWrapper(
        wrapped, FLAGS.separate_message_state_counter, key="state_visits_m"
    )


# Helper functions and wrappers
def _format_observation(obs_dict):
    obs = torch.tensor(obs_dict["image"])
    obs = obs.view((1, 1) + obs.shape)  # (...) -> (T,B,...).
    obs_processed = {
        "image": obs,
        "partial_image": torch.tensor(obs_dict["partial_image"]).view(
            (1, 1) + obs_dict["partial_image"].shape
        ),
        "state_visits": torch.tensor(obs_dict["state_visits"]),
    }
    if "subgoal_done" in obs_dict:
        obs_processed["subgoal_done"] = torch.from_numpy(obs_dict["subgoal_done"])
    else:
        obs_processed["subgoal_done"] = torch.zeros(len(INSTRS), dtype=torch.uint8)
    if "subgoal_achievable" in obs_dict:
        obs_processed["subgoal_achievable"] = torch.from_numpy(
            obs_dict["subgoal_achievable"]
        )
    else:
        obs_processed["subgoal_achievable"] = torch.zeros(
            len(INSTRS), dtype=torch.uint8
        )
    return obs_processed


class Minigrid2Image(gym.ObservationWrapper):
    """Reduce a MiniGrid observation to its image, dropping the language instruction.

    ``observation_space`` becomes the wrapped env's ``"image"`` sub-space, and
    the wrapped env's ``partial_observation_space`` is re-exposed unchanged.
    """

    def __init__(self, env):
        super().__init__(env)
        image_space = env.observation_space.spaces["image"]
        self.observation_space = image_space
        self.partial_observation_space = env.partial_observation_space

    def observation(self, observation):
        image = observation["image"]
        return image


class Minigrid2ImageWithSubgoals(Minigrid2Image):
    """MiniGrid observation that keeps the images and any subgoal information.

    The returned dict always holds ``"image"`` and ``"partial_image"``, plus
    ``"subgoal_done"`` / ``"subgoal_achievable"`` when the incoming
    observation provides them. Everything else (e.g. the mission) is dropped.

    NOTE(review): ``observation_space`` is inherited from ``Minigrid2Image``
    and describes only the image, although this wrapper returns a dict.
    """

    def observation(self, obs):
        processed = {name: obs[name] for name in ("image", "partial_image")}
        processed.update(
            {
                name: obs[name]
                for name in ("subgoal_done", "subgoal_achievable")
                if name in obs
            }
        )
        return processed


class FullyObsWrapperWithSubgoals(wrappers.FullyObsWrapper):
    """Fully observable compact-grid observation that forwards subgoal info.

    Each observation becomes a dict with:

    * ``"mission"``: passed through from the wrapped observation;
    * ``"image"``: ``grid.encode()`` of the unwrapped env, with the agent's
      cell overwritten by (agent object index, red colour index, direction);
    * ``"partial_image"``: the wrapped env's original (egocentric) image;
    * ``"subgoal_done"`` / ``"subgoal_achievable"``: copied when present.
    """

    def __init__(self, env):
        # Read the egocentric image space before the parent initialiser runs;
        # presumably FullyObsWrapper replaces observation_space with the
        # full-grid space -- confirm against gym_minigrid.
        self.partial_observation_space = env.observation_space["image"]
        super().__init__(env)

    def observation(self, obs):
        base = self.unwrapped
        encoded = base.grid.encode()
        ax, ay = base.agent_pos[0], base.agent_pos[1]
        agent_cell = np.array(
            [
                minigrid.OBJECT_TO_IDX["agent"],
                minigrid.COLOR_TO_IDX["red"],
                base.agent_dir,
            ]
        )
        encoded[ax][ay] = agent_cell

        result = {
            "mission": obs["mission"],
            "image": encoded,
            "partial_image": obs["image"],
        }
        for name in ("subgoal_done", "subgoal_achievable"):
            if name in obs:
                result[name] = obs[name]
        return result


class Observation_WrapperSetup:
    """Environment wrapper to format observation items into torch.

    ``reset`` and ``step`` return a dict with:

    * the formatted observation from ``_format_observation`` (``frame`` and
      ``partial_frame`` carry leading (T=1, B=1) dims), subgoal tensors and
      ``state_visits``;
    * ``reward`` and ``done`` shaped (1, 1);
    * running per-episode statistics shaped (1, 1): ``episode_return``,
      ``intrinsic_episode_step``, ``extrinsic_episode_step`` and
      ``episode_win`` (1 on the step that ends an episode with a positive
      reward, else 0);
    * ``carried_col`` / ``carried_obj``: (1, 1) LongTensors describing the
      item in ``gym_env.carrying``.

    Every output holds its own statistics tensors; later steps never modify
    tensors that were already returned.

    Args:
        gym_env: wrapped environment exposing ``reset``, ``step``, ``close``
            and a ``carrying`` attribute (falsy when nothing is carried).
        reset_when_done: if True, ``step`` resets the environment as soon as
            an episode ends; that step's output then contains the first
            observation of the new episode together with the finished
            episode's reward, done flag and statistics.
    """

    def __init__(self, gym_env, reset_when_done=False):
        self.gym_env = gym_env
        self.episode_return = None
        self.intrinsic_episode_step = None
        self.extrinsic_episode_step = None
        self.episode_win = None
        self.reset_when_done = reset_when_done

    def get_initial_env_state(self, env_output):
        """Get the initial env state relevant for the teacher."""
        return {"frame": env_output["frame"]}

    def _reset_episode_stats(self):
        """Bind fresh zeroed (1, 1) tensors for the per-episode statistics."""
        self.episode_return = torch.zeros(1, 1)
        self.intrinsic_episode_step = torch.zeros(1, 1, dtype=torch.int32)
        self.extrinsic_episode_step = torch.zeros(1, 1, dtype=torch.int32)
        self.episode_win = torch.zeros(1, 1, dtype=torch.int32)

    def _carried(self):
        """Return (colour, object) index LongTensors of shape (1, 1) for the carried item."""
        carrying = self.gym_env.carrying
        if carrying:
            return (
                torch.LongTensor([[COLOR_TO_IDX[carrying.color]]]),
                torch.LongTensor([[OBJECT_TO_IDX[carrying.type]]]),
            )
        # Nothing carried: object index 1 is treated as "empty" elsewhere in
        # this module; colour index 5 is presumably gym_minigrid's "grey" --
        # confirm against COLOR_TO_IDX.
        return torch.LongTensor([[5]]), torch.LongTensor([[1]])

    def _build_output(self, obs, reward, done):
        """Assemble the output dict for ``obs`` using the current episode stats."""
        env_output = _format_observation(obs)
        carried_col, carried_obj = self._carried()
        return dict(
            frame=env_output["image"],
            partial_frame=env_output["partial_image"],
            subgoal_done=env_output["subgoal_done"],
            subgoal_achievable=env_output["subgoal_achievable"],
            reward=reward,
            done=done,
            state_visits=env_output["state_visits"],
            episode_return=self.episode_return,
            intrinsic_episode_step=self.intrinsic_episode_step,
            extrinsic_episode_step=self.extrinsic_episode_step,
            episode_win=self.episode_win,
            carried_col=carried_col,
            carried_obj=carried_obj,
        )

    def reset(self):
        """Reset the environment and statistics; return the initial output dict."""
        self._reset_episode_stats()
        obs = self.gym_env.reset()
        return self._build_output(
            obs,
            reward=torch.zeros(1, 1),
            done=torch.ones(1, 1, dtype=torch.uint8),
        )

    def step(self, action):
        """Apply ``action`` (a one-element tensor) and return the output dict."""
        obs, reward, done, _ = self.gym_env.step(action.item())

        # Rebind instead of updating in place, so tensors already handed out
        # in earlier outputs (including the one from reset) keep their values.
        self.intrinsic_episode_step = self.intrinsic_episode_step + 1
        self.extrinsic_episode_step = self.extrinsic_episode_step + 1
        self.episode_return = self.episode_return + reward
        episode_win = torch.zeros(1, 1, dtype=torch.int32)
        if done and reward > 0:
            episode_win[0][0] = 1
        self.episode_win = episode_win

        start_new_episode = done and self.reset_when_done
        if start_new_episode:
            obs = self.gym_env.reset()

        output = self._build_output(
            obs,
            reward=torch.tensor(reward).view(1, 1),
            done=torch.tensor(done).view(1, 1),
        )
        if start_new_episode:
            # The output above reports the finished episode's statistics; only
            # now start fresh ones for the new episode.
            self._reset_episode_stats()
        return output

    def close(self):
        self.gym_env.close()


def frame_to_str(frame, goal=None):
    """
    Produce a pretty string of the environment's grid along with the agent.

    A grid cell is represented by a 2-character string, the first one for
    the object and the second one for the color. ``frame`` is indexed as
    ``frame[x, y]`` (axis 0 = column, axis 1 = row) with a last axis of
    (type, color, state) indices; each row ``y`` becomes one output line.

    Args:
        frame: numpy array or torch tensor of shape (width, height, 3).
        goal: optional flat index (int or one-element tensor) into the
            (width, height) grid, i.e. ``x * height + y``; that cell's first
            character is replaced by "@".

    Returns:
        The rendered grid with rows separated by newlines.

    Raises:
        RuntimeError: for a door cell with an unknown state value.
    """
    if isinstance(frame, torch.Tensor):
        frame = frame.cpu().numpy()

    # Map of object types to short string
    OBJECT_TO_STR = {
        "wall": "W",
        "floor": "F",
        "door": "D",
        "key": "K",
        "ball": "A",
        "box": "B",
        "goal": "G",
        "lava": "V",
        "unseen": ".",
    }

    # Map agent's direction to short string
    AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}

    # Axis 0 is x (width) and axis 1 is y (height), matching frame[w, h] below;
    # iterating the other way round breaks non-square grids.
    width, height = frame.shape[0], frame.shape[1]

    # The np.bool alias was removed in NumPy 1.24; builtin bool is the same dtype.
    goal_channel = np.zeros(width * height, dtype=bool)
    if goal is not None:
        if isinstance(goal, torch.Tensor):
            goal = goal.item()
        goal_channel[goal] = True
    goal_channel = goal_channel.reshape((width, height))

    rows = []
    for h in range(height):
        cells = []
        for w in range(width):
            obj_type, obj_color, obj_feat = frame[w, h]
            if obj_type == 10:  # Agent location
                c = AGENT_DIR_TO_STR[obj_feat] * 2
            elif obj_type == 11:  # Special indicator for debugging, etc
                c = "##"
            elif obj_type == 1:  # Empty
                c = "  "
            else:
                otype = minigrid.IDX_TO_OBJECT[obj_type]
                color_char = minigrid.IDX_TO_COLOR[obj_color][0].upper()
                if otype == "door":
                    if obj_feat == 0:  # Open
                        c = "__"
                    elif obj_feat == 1:  # Closed
                        c = "D" + color_char
                    elif obj_feat == 2:  # Locked
                        c = "L" + color_char
                    else:
                        raise RuntimeError(f"Unknown obj state {obj_feat}")
                elif otype == "unseen":
                    c = ".."
                else:
                    c = OBJECT_TO_STR[otype] + color_char

            if goal_channel[w, h]:
                c = "@" + c[1]
            cells.append(c)
        rows.append("".join(cells))

    return "\n".join(rows)
