from typing import Optional, Any, Dict

import gym
import numpy as np

from extensions.rl_poisoneddoors.poisoneddoors_tasks import (
    PoisonedDoorsEnvironment,
    PoisonedDoorsTask,
    PoisonedEnvStates,
)
from rl_base.sensor import Sensor


class PoisonedDoorCurrentStateSensor(
    Sensor[PoisonedDoorsEnvironment, PoisonedDoorsTask]
):
    """Expose the poisoned-doors environment's current state as an observation.

    Each reading is a one-element integer array holding
    ``int(env.current_state.value)``. The declared observation space is a
    ``gym.spaces.Box`` of shape ``(1,)`` bounded by ``[0, nstates - 1]``,
    where ``nstates == len(PoisonedEnvStates)``.
    """

    def __init__(self, config: Dict[str, Any], *args: Any, **kwargs: Any):
        """Initialise the sensor and build its observation space.

        :param config: Sensor configuration, forwarded unchanged (together
            with any extra positional/keyword arguments) to ``Sensor``.
        """
        super().__init__(config, *args, **kwargs)

        # Count of members in the PoisonedEnvStates enum; it fixes the upper
        # bound of the observation space.
        self.nstates = len(PoisonedEnvStates)
        # Built only once ``nstates`` exists, since the space's bound uses it.
        self.observation_space = self._get_observation_space()

    def _get_uuid(self, *args: Any, **kwargs: Any) -> str:
        """Return this sensor's fixed identifier, ``"poisoned_door_state"``."""
        identifier = "poisoned_door_state"
        return identifier

    def _get_observation_space(self) -> gym.Space:
        """Describe a reading as a single integer in ``[0, nstates - 1]``."""
        # NOTE(review): assumes PoisonedEnvStates values are exactly
        # 0..nstates-1 so every ``current_state.value`` fits this box — confirm.
        largest_state = self.nstates - 1
        return gym.spaces.Box(
            low=0,
            high=largest_state,
            shape=(1,),
            dtype=int,
        )

    def get_observation(
        self,
        env: PoisonedDoorsEnvironment,
        task: Optional[PoisonedDoorsTask],
        *args: Any,
        minigrid_output_obs: Optional[np.ndarray] = None,
        **kwargs: Any
    ) -> Any:
        """Read the current state of ``env``.

        :param env: Environment whose ``current_state`` is reported.
        :param task: Not used by this sensor.
        :param minigrid_output_obs: Not used by this sensor.
        :return: ``np.ndarray`` of shape ``(1,)`` containing
            ``int(env.current_state.value)``.
        """
        state_value = int(env.current_state.value)
        return np.array([state_value])
