import typing
from typing import Optional, Any, Dict

import gym
import gym_minigrid.minigrid
import numpy as np
from gym_minigrid.minigrid import MiniGridEnv

from rl_base.sensor import Sensor
from rl_base.task import Task, SubTaskType


class EgocentricMiniGridSensor(Sensor[MiniGridEnv, Task[MiniGridEnv]]):
    """Sensor returning the agent's egocentric MiniGrid image observation.

    The observation is the ``"image"`` entry of a MiniGrid observation
    dictionary, truncated along its last axis to the first
    ``view_channels`` channels.

    Config keys read in ``__init__``:
        agent_view_size: Required. Side length of the square egocentric view.
        view_channels: Optional, defaults to 1. Number of leading image
            channels kept in the returned observation.
    """

    def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
        """Read the view configuration and build the observation space.

        Args:
            config: Sensor configuration; see the class docstring for the
                keys that are read here.
            *args: Forwarded unchanged to the base ``Sensor`` constructor.
            **kwargs: Forwarded unchanged to the base ``Sensor`` constructor.

        Raises:
            KeyError: If ``config`` has no ``"agent_view_size"`` entry.
        """
        super().__init__(config, *args, **kwargs)
        self.agent_view_size = config["agent_view_size"]
        self.view_channels = config.get("view_channels", 1)

        # Each count is one more than the largest absolute index in the
        # corresponding MiniGrid lookup table, so that every index in the
        # table is strictly smaller than the count.
        self.num_objects = (
            typing.cast(
                int, max(map(abs, gym_minigrid.minigrid.OBJECT_TO_IDX.values()))
            )
            + 1
        )
        self.num_colors = (
            typing.cast(int, max(map(abs, gym_minigrid.minigrid.COLOR_TO_IDX.values())))
            + 1
        )
        self.num_states = (
            typing.cast(int, max(map(abs, gym_minigrid.minigrid.STATE_TO_IDX.values())))
            + 1
        )
        self.observation_space = self._get_observation_space()

    def _get_uuid(self, *args: Any, **kwargs: Any) -> str:
        """Return the fixed identifier string of this sensor."""
        return "minigrid_ego_image"

    def _get_observation_space(self) -> gym.Space:
        """Build the ``Box`` space describing the returned image.

        The space has shape ``(agent_view_size, agent_view_size,
        view_channels)``, lower bound 0 and an upper bound equal to the
        largest index across the object, color and state tables.
        """
        # NOTE: the space is declared with dtype=int although
        # `get_observation` asserts that images are np.uint8; this is kept
        # as-is because consumers may depend on the declared dtype.
        return gym.spaces.Box(
            low=0,
            high=max(self.num_objects, self.num_colors, self.num_states) - 1,
            shape=(self.agent_view_size, self.agent_view_size, self.view_channels),
            dtype=int,
        )

    def get_observation(
        self,
        env: MiniGridEnv,
        task: Optional[SubTaskType],
        *args,
        minigrid_output_obs: Optional[Dict[str, Any]] = None,
        **kwargs: Any
    ) -> Any:
        """Return the egocentric image truncated to ``view_channels`` channels.

        Args:
            env: Environment to query when no usable cached observation is
                supplied.
            task: Unused by this sensor.
            *args: Unused.
            minigrid_output_obs: Optional observation dictionary that was
                already produced by the environment. Its ``"image"`` entry
                is reused when its two leading dimensions both equal
                ``agent_view_size``.
            **kwargs: Unused.

        Returns:
            The image array sliced to its first ``view_channels`` channels.

        Raises:
            AssertionError: If the resulting image is not of dtype
                ``np.uint8``.
        """
        view_shape = (self.agent_view_size, self.agent_view_size)

        # Only the two spatial dimensions are compared: the image is indexed
        # along three axes below, so its full shape has three entries and
        # could never equal the 2-tuple `view_shape`. Comparing the full
        # shape would make the cached observation unusable.
        if (
            minigrid_output_obs is not None
            and tuple(minigrid_output_obs["image"].shape[:2]) == view_shape
        ):
            img = minigrid_output_obs["image"][:, :, : self.view_channels]
        else:
            # No usable cached observation: set the environment's view size
            # to the configured one and generate a fresh observation.
            env.agent_view_size = self.agent_view_size
            img = env.gen_obs()["image"][:, :, : self.view_channels]

        assert img.dtype == np.uint8
        return img
