from collections import defaultdict
from typing import Union

import gym
import numpy as np
from custom_minigrid.envs import COLOR_TO_IDX, OBJECT_TO_IDX
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder


def wrap_env(env, *wrappers):
    """Apply each wrapper in ``wrappers`` to ``env``, in order.

    The first wrapper ends up innermost: ``wrap_env(e, A, B)`` returns
    ``B(A(e))``. With no wrappers, ``env`` itself is returned.
    """
    wrapped = env
    for wrapper in wrappers:
        wrapped = wrapper(wrapped)
    return wrapped

def make_env_factory(env):
    """Return a zero-argument callable that always yields this same ``env`` object.

    No new environment is created per call; every call returns the instance
    passed in.
    """
    def _factory():
        return env
    return _factory

class FlatObsWrapper(gym.core.ObservationWrapper):
    """Fully observable gridworld observation as a 2-D array of object ids.

    The observation produced by the wrapped env is discarded. Each observation
    is rebuilt from the unwrapped env's full grid (``grid.encode()``): the
    agent's cell is overwritten with the ``"agent"`` object id, and only the
    first channel (object ids) of the encoding is returned.
    """

    def __init__(self, env):
        super().__init__(env)

        # One uint8 entry per grid cell: shape (width, height).
        cells = (self.env.width, self.env.height)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=cells, dtype='uint8'
        )

    def observation(self, obs):
        """Ignore ``obs`` and encode the unwrapped env's full grid instead."""
        base_env = self.unwrapped
        encoded = base_env.grid.encode()
        return self.flatten_obs(encoded, base_env.agent_pos, base_env.agent_dir)

    def from_observation(self, obs):
        """Set the unwrapped env's agent position from an observation array.

        The position is taken from the cell(s) of ``obs`` equal to the
        ``"agent"`` object id. Only the position is written; the agent's
        direction is left untouched. Returns ``self`` for chaining.
        """
        agent_mask = obs == self.object_id("agent")
        self.unwrapped.agent_pos = np.array(np.where(agent_mask)).flatten()
        return self

    def to_observation(self):
        """Return this wrapper's observation for the env's current state."""
        current = self.unwrapped.gen_obs()
        return self.observation(current)

    def render(self, *args, **kwargs):
        """Render the unwrapped env with ``highlight`` forced to ``False``."""
        return self.unwrapped.render(*args, **{**kwargs, 'highlight': False})

    @staticmethod
    def flatten_obs(grid, agent_pos, agent_dir=0):
        """Stamp the agent into ``grid`` and return its object-id channel.

        ``grid`` is modified in place: the agent's cell is set to
        ``[agent object id, red color id, agent_dir]``. The outer walls are
        kept (the result covers the whole grid).
        """
        x, y = agent_pos[0], agent_pos[1]
        grid[x, y] = np.array([
            FlatObsWrapper.object_id("agent"),
            FlatObsWrapper.color_id("red"),
            agent_dir,
        ])
        return grid[:, :, 0]

    @staticmethod
    def object_id(obj):
        """Return the integer index of object type ``obj`` in ``OBJECT_TO_IDX``."""
        return OBJECT_TO_IDX[obj]

    @staticmethod
    def color_id(color):
        """Return the integer index of ``color`` in ``COLOR_TO_IDX``."""
        return COLOR_TO_IDX[color]

class AccessImageObsWrapper(gym.core.ObservationWrapper):
    """Fully observable gridworld returning a rendered image of the whole grid.

    The wrapped env's observation is discarded. Each observation is the
    unwrapped env's ``grid.render(...)`` at ``tile_size`` pixels per cell,
    without highlighting, with its first and last axes swapped. The declared
    ``observation_space`` has shape ``(3, width * tile_size, height * tile_size)``.
    NOTE(review): this assumes ``grid.render`` returns a channels-last
    ``(rows, cols, 3)`` array, so the swap puts channels first -- TODO confirm.
    """

    def __init__(self, env, tile_size=32):
        super().__init__(env)

        # Pixels per grid cell used when rendering observations.
        self._tile_size = tile_size

        base_env = env.unwrapped
        pixels_w = base_env.width * self._tile_size
        pixels_h = base_env.height * self._tile_size
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(3, pixels_w, pixels_h), dtype='uint8'
        )

    def observation(self, obs):
        """Ignore ``obs`` and render the unwrapped env's full grid."""
        base_env = self.unwrapped
        frame = base_env.grid.render(
            self._tile_size,
            base_env.agent_pos,
            base_env.agent_dir,
            highlight_mask=None,
        )
        return frame.swapaxes(0, -1)

    def render(self, *args, **kwargs):
        """Render the unwrapped env with ``highlight`` forced to ``False``."""
        return self.unwrapped.render(*args, **{**kwargs, 'highlight': False})

class ActionCountTracker(gym.core.Wrapper):
    """Count how many times each action is taken during the current episode.

    The per-episode counts are reported on every step as
    ``info["actions_counter"]``, a mapping from action to count.
    """

    def __init__(self, env) -> None:
        super().__init__(env)

        # action -> number of times taken since the last reset()
        self.counter = defaultdict(int)

    def step(self, action):
        self.counter[action] += 1

        obs, reward, done, info = self.env.step(action)

        # Report a snapshot: handing out the live counter would make any info
        # dict a caller keeps change retroactively as later steps are taken.
        info["actions_counter"] = self.counter.copy()

        return obs, reward, done, info

    def reset(self, **kwargs):
        """Clear the per-episode counts and reset the wrapped env.

        Keyword arguments (e.g. ``seed``) are forwarded to the wrapped env's
        ``reset`` so outer wrappers that pass them do not fail here.
        """
        self.counter = defaultdict(int)
        return self.env.reset(**kwargs)

class GridWorldInapplicableActionsTracker(gym.core.Wrapper):
    """Count actions after which the agent's position did not change.

    Such actions are treated as "inapplicable". Each one is also bucketed by
    the type of the grid cell at ``get_fwd_pos(action)`` (``"empty"`` when
    that cell is ``None``) and by action name. On every step ``info`` gets:

    * ``inapplicable_actions`` -- count in the current episode;
    * ``total_inapplicable_actions`` -- count since construction (not reset);
    * ``inapplicable/<cell type>`` and ``inapplicable/<action>/<cell type>``
      -- per-episode breakdowns.
    """

    def __init__(self, env) -> None:
        super().__init__(env)

        # Per-episode count (cleared by reset()).
        self.inapplicable_actions = 0
        # Count over the wrapper's lifetime (never cleared).
        self.total_inapplicable_actions = 0
        # Per-episode counts keyed by forward-cell type.
        self.inapplicable_types = defaultdict(int)
        # Per-episode counts keyed by action name, then forward-cell type.
        self.inapplicable_action_types = defaultdict(lambda: defaultdict(int))

    def step(self, action):
        base_env = self.unwrapped
        # Copy: we need the position as it was before the step.
        prev_agent_pos = np.asarray(base_env.agent_pos).copy()

        obs, reward, done, info = self.env.step(action)

        new_agent_pos = np.asarray(base_env.agent_pos)
        # NOTE(review): the step where step_count == 1 (first step of an
        # episode) is never counted; the reason is not visible here -- confirm.
        if np.array_equal(prev_agent_pos, new_agent_pos) and base_env.step_count != 1:
            self.inapplicable_actions += 1
            self.total_inapplicable_actions += 1

            # Public Enum lookup-by-value rather than the private
            # _value2member_map_; assumes ``actions`` is the env's Enum class.
            action_name = base_env.actions(action).name
            # Presumably the cell the action would have moved into -- confirm
            # against custom_minigrid's get_fwd_pos.
            fwd_pos = base_env.get_fwd_pos(action)
            fwd_cell = base_env.grid.get(*fwd_pos)
            cell_type = fwd_cell.type if fwd_cell else "empty"
            self.inapplicable_types[cell_type] += 1
            self.inapplicable_action_types[action_name][cell_type] += 1

        info["inapplicable_actions"] = self.inapplicable_actions
        info["total_inapplicable_actions"] = self.total_inapplicable_actions
        for cell_type, count in self.inapplicable_types.items():
            info[f"inapplicable/{cell_type}"] = count
        for action_name, by_cell in self.inapplicable_action_types.items():
            for cell_type, count in by_cell.items():
                info[f"inapplicable/{action_name}/{cell_type}"] = count

        return obs, reward, done, info

    def reset(self, **kwargs):
        """Clear the per-episode counters and reset the wrapped env.

        ``total_inapplicable_actions`` is intentionally kept. Keyword
        arguments (e.g. ``seed``) are forwarded to the wrapped env's ``reset``.
        """
        self.inapplicable_actions = 0
        self.inapplicable_types.clear()
        self.inapplicable_action_types.clear()
        return self.env.reset(**kwargs)

class BadActionsTracker(gym.core.Wrapper):
    """Count actions that do not decrease ``_compute_distance`` of the agent.

    On every step ``info`` gets ``bad_actions`` (count in the current
    episode) and ``total_bad_actions`` (count since construction, not reset).
    """

    def __init__(self, env) -> None:
        super().__init__(env)

        # Per-episode count (cleared by reset()).
        self.bad_actions = 0
        # Count over the wrapper's lifetime (never cleared).
        self.total_bad_actions = 0

    def step(self, action):
        current_position = np.asarray(self.unwrapped.agent_pos).copy()
        current_distance_to_goal = self._compute_distance(current_position)

        obs, reward, done, info = self.env.step(action)

        new_agent_pos = np.asarray(self.unwrapped.agent_pos).copy()
        new_distance_to_goal = self._compute_distance(new_agent_pos)
        # Anything that fails to strictly decrease the distance (including
        # not moving at all) counts as a bad action.
        if new_distance_to_goal >= current_distance_to_goal:
            self.bad_actions += 1
            self.total_bad_actions += 1

        info["bad_actions"] = self.bad_actions
        info["total_bad_actions"] = self.total_bad_actions

        return obs, reward, done, info

    def _compute_distance(self, agent_pos):
        """Return the agent's position along a serpentine corridor maze.

        Layout-specific. Interior rows with an odd index are treated as
        full-width corridors (cost ``width``). Interior rows with an even index
        are treated as one-cell connectors (cost 1). Within the agent's own
        row, the offset depends on ``agent_pos[1] % 4``.

        For the 7x7 sketch below this yields 0 at '@' (5, 5) and 16 at
        'G' (1, 1). NOTE(review): since ``step`` counts a move as bad unless
        this value strictly decreases, moves toward 'G' in the sketch are
        counted as bad. The target is effectively the bottom-row start of the
        corridor; the commented-out "Inverse problem" transform flips this.
        Confirm which orientation the environment uses.
        """
        # XXXXXXX
        # XG....X
        # XXXXX.X
        # X.....X
        # X.XXXXX
        # X....@X
        # XXXXXXX

        distance = 0

        # Interior size, excluding the outer wall ring.
        width = self.unwrapped.width - 2
        height = self.unwrapped.height - 2

        # Inverse problem (180-degree rotation of the position):
        # agent_pos = (self.unwrapped.height - agent_pos[0] - 1, self.unwrapped.width - agent_pos[1] - 1)

        # Add the cost of every row strictly below the agent's row.
        for r in range(height, 0, -1):
            if agent_pos[1] >= r:
                break
            if r % 2 == 0:
                distance += 1
            else:
                distance += width

        # Offset within the agent's own row; its direction alternates by row.
        distance += (width - agent_pos[0]) if (agent_pos[1] % 4 in (1, 2)) else (agent_pos[0] - 1)

        return distance

    def reset(self, **kwargs):
        """Clear the per-episode count and reset the wrapped env.

        ``total_bad_actions`` is intentionally kept. Keyword arguments
        (e.g. ``seed``) are forwarded to the wrapped env's ``reset``.
        """
        self.bad_actions = 0
        return self.env.reset(**kwargs)

class VideoRecorder(VecVideoRecorder):
    """``VecVideoRecorder`` that also accepts a plain, non-vectorized env.

    A ``DummyVecEnv`` is used as-is. Anything else is wrapped in ``Monitor``
    and then in a single-env ``DummyVecEnv``. Despite its name,
    ``start_recording_step`` acts as a period: the recording trigger returns
    True for every step index that is a multiple of it (0 included).
    """

    def __init__(
        self,
        env: Union[DummyVecEnv, gym.Env],
        video_folder: str,
        start_recording_step: int = 1000,
        video_length: int = 200,
        name_prefix: str = "rl-video"
    ):
        if isinstance(env, DummyVecEnv):
            venv = env
        else:
            venv = DummyVecEnv([lambda: Monitor(env)])

        def record_video_trigger(step_id):
            return step_id % start_recording_step == 0

        super().__init__(venv, video_folder, record_video_trigger, video_length, name_prefix)
