"""gym wrappers for MDP & monitor environments"""
import numpy as np
import gymnasium as gym


class WindowViewObs(gym.ObservationWrapper):
    """Observation wrapper that crops a square window around the agent.

    The parent env must expose a 3-D (``C x H x W``) observation space. On
    every observation the agent is located by searching channel
    ``agent_channel`` for ``agent_id``, and a ``window_size x window_size``
    crop of all channels is returned with the agent at row/column
    ``window_size // 2`` (the exact centre for odd sizes).

    The crop is never padded: if the window would extend past the grid
    border, ``observation`` raises ``ValueError``. Padding the grid by at
    least ``window_size // 2`` cells on every side beforehand avoids this.
    """

    def __init__(
            self, env, window_size: int,
            agent_channel: int = 0,
            agent_id: int = 1,
            image_obs: bool = False,
            flatten_obs: bool = False,
    ):
        """
        Args:
            env: environment whose observation space is 3-D (C x H x W).
            window_size: side length of the square crop; must not exceed the
                grid height or width.
            agent_channel: channel searched for the agent marker.
            agent_id: value marking the agent's cell in ``agent_channel``.
            image_obs: declare the space as ``uint8`` in ``[0, 255]`` instead
                of using the parent's bounds with ``int8``.
            flatten_obs: return the crop as a 1-D array of length
                ``C * window_size ** 2``.

        Raises:
            ValueError: if the parent space is not 3-D or the window is
                larger than the grid.
        """
        super().__init__(env)
        self.env = env
        old_shape = self.env.observation_space.shape
        # Validate the rank first so indexing old_shape[1]/[2] cannot IndexError.
        if len(old_shape) != 3:
            raise ValueError("Parent env should have 3 dimensions CxHxW")
        if window_size > old_shape[1] or window_size > old_shape[2]:
            raise ValueError("The window size should not exceed the grid height or width")
        self.window_size = window_size
        self.agent_channel = agent_channel
        self.agent_id = agent_id
        self.image_obs = image_obs
        self.flatten_obs = flatten_obs

        if image_obs:
            # int8 tops out at 127, so the [0, 255] image range needs uint8.
            low, high, dtype = 0, 255, np.uint8
        else:
            # Read the bounds from the arrays themselves; the *_repr strings only
            # parse with int() when every bound is the same integer.
            low = int(np.min(env.observation_space.low))
            high = int(np.max(env.observation_space.high))
            dtype = np.int8

        shape = (old_shape[0], window_size, window_size)
        if flatten_obs:
            # observation() returns a flat vector in this mode; declare it so.
            shape = (int(np.prod(shape)),)
        self.observation_space = gym.spaces.Box(low, high, shape, dtype=dtype)

    def observation(self, obs: np.ndarray) -> np.ndarray:
        """Return the window of ``obs`` around the agent.

        Raises:
            IndexError: if ``agent_id`` does not occur in ``agent_channel``.
            ValueError: if the window would extend past the grid border.
        """
        last_x, last_y = self.env.observation_space.shape[1:3]
        # Search once; if several cells match, the first one is used.
        xs, ys = np.where(obs[self.agent_channel] == self.agent_id)
        pos_x, pos_y = xs[0], ys[0]

        half = self.window_size // 2
        # End = start + window_size gives exactly window_size cells for odd
        # and even sizes alike (odd sizes match pos - half .. pos + half).
        start_x, start_y = pos_x - half, pos_y - half
        min_x, max_x = max(0, start_x), min(last_x, start_x + self.window_size)
        min_y, max_y = max(0, start_y), min(last_y, start_y + self.window_size)
        window_obs = obs[:, min_x:max_x, min_y:max_y]
        # Clipping at the border shrinks the crop; that is reported, not padded.
        if window_obs.shape[1:] != (self.window_size, self.window_size):
            raise ValueError("mismatch shape of observation in window wrapper")
        return window_obs.reshape(-1) if self.flatten_obs else window_obs


class WallObs(gym.ObservationWrapper):
    """Pad the observation with a border of walls and add a wall channel.

    The parent's ``C x H x W`` observation is copied into the middle of a
    zero array of shape ``(C + 1, H + 2 * n_walls, W + 2 * n_walls)``. The
    extra last channel is 1 on the ``n_walls``-wide border and 0 inside.
    """

    def __init__(self, env, grid_size: tuple, n_walls: int):
        """
        Args:
            env: environment with a 3-D (C x H x W) observation space.
            grid_size: stored as ``self.grid_size``; not used by this wrapper.
            n_walls: border thickness in cells; 0 adds only an empty wall
                channel.

        Raises:
            ValueError: if ``n_walls`` is negative.
        """
        super().__init__(env)
        if n_walls < 0:
            raise ValueError("n_walls must be non-negative")
        self.env = env
        self.n_walls = n_walls
        self.grid_size = grid_size
        new_shape = list(self.env.observation_space.shape)
        new_shape[0] += 1
        new_shape[1] += 2 * self.n_walls
        new_shape[2] += 2 * self.n_walls
        self.shape = tuple(new_shape)
        # NOTE(review): declared uint8 with high = channel count, while
        # observation() returns a float64 array; confirm this is intended.
        self.observation_space = gym.spaces.Box(0, self.shape[0], shape=self.shape, dtype=np.uint8)

    def observation(self, obs: np.ndarray) -> np.ndarray:
        """Return ``obs`` padded by ``n_walls`` cells plus a trailing wall channel."""
        n = self.n_walls
        new_obs = np.zeros(self.shape)
        height = self.shape[1] - 2 * n
        width = self.shape[2] - 2 * n
        # Explicit end indices: the `n:-n` form collapses to an empty slice
        # when n == 0, and `-n:` would then select the whole axis.
        new_obs[:-1, n:n + height, n:n + width] = obs

        walls = new_obs[-1]  # view into the wall channel
        walls[:n, :] = 1
        walls[:, :n] = 1
        walls[n + height:, :] = 1
        walls[:, n + width:] = 1
        return new_obs

    def get_n_walls(self):
        """Return the border thickness this wrapper pads with."""
        return self.n_walls


class MultiChannel(gym.ObservationWrapper):
    """Expand observation channel 1 into three colour channels.

    Each known value in channel 1 is mapped to a 3-component colour; cells
    with any other value get all zeros. The output keeps channel 0 first,
    then the three colour channels, then channels 2 onward unchanged, for
    ``C + 2`` channels in total.
    """

    # (channel-1 value, colour) pairs; labels carried over from the original.
    _PALETTE = (
        (1.0, (1, 1, 0)),    # plant
        (0.5, (0, 1, 1)),    # cactus
        (0.125, (0, 0, 1)),  # different
        (0.25, (0, 1, 0)),   # different
        (0.375, (1, 0, 0)),  # different
        (0.625, (1, 0, 1)),  # different
        (0.875, (1, 1, 1)),  # different
    )

    def __init__(self, env, normalize_obs: bool = False):
        """
        Args:
            env: environment with a 3-D (C x H x W) observation space.
            normalize_obs: scale each colour so its components sum to 1.
        """
        super().__init__(env)
        self.env = env
        self.normalize_obs = normalize_obs
        shape = self.env.observation_space.shape
        # NOTE(review): declared uint8 in [0, 1], but observation() returns
        # float64 values such as 0.5 when normalize_obs is set; confirm.
        self.observation_space = gym.spaces.Box(0, 1, shape=(shape[0] + 2, shape[1], shape[2]), dtype=np.uint8)

    def observation(self, obs: np.ndarray) -> np.ndarray:
        """Return ``obs`` with channel 1 replaced by its three colour channels."""
        tile_values = obs[1]
        colours = np.zeros((3,) + tile_values.shape)
        for value, rgb in self._PALETTE:
            colour = np.array(rgb)
            if self.normalize_obs:
                # Divide by the number of lit components (2 for plant, 3 for
                # 0.875, 1 for the single-component colours).
                colour = colour / colour.sum()
            colours[:, tile_values == value] = colour[:, None]
        # obs[2:] keeps every remaining channel, matching the declared C + 2.
        return np.concatenate((obs[:1], colours, obs[2:]), axis=0)
