from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, cast

import gymnasium as gym
import numpy as np

# Generic parameters of the wrapper: the observation type produced by the
# wrapped environment, and the action type it accepts.
SInType, ActType = TypeVar("SInType"), TypeVar("ActType")


class RGBObservationWrapper(
	gym.ObservationWrapper[dict, ActType, SInType], Generic[SInType, ActType]
):
	"""Wrapper that adds a flattened RGB frame to every observation.

	Observations become ``{"state": <original observation>, "rgb": <frame>}``
	where ``frame`` is a flat ``uint8`` vector of length
	``height * width * 3``, matching the declared observation space.

	The frame comes from ``env.render()``, so the wrapped environment must
	return an ``(H, W, C)`` array with at least three channels from
	``render()``. Extra channels (e.g. alpha) are dropped, and frames whose
	size differs from ``(height, width)`` are resized with nearest-neighbour
	sampling.

	Rendering is only done on the first observation of each episode and
	then on every ``rate``-th observation. In between, the most recently
	rendered frame is repeated. If ``render()`` returns ``None``, the
	previous frame is kept. At the start of an episode that previous frame
	is all zeros.

	Args:
		env (gym.Env): The environment to wrap.
		width (int): Width of the RGB image.
		height (int): Height of the RGB image.
		rate (int): Render a new frame every ``rate`` observations. Must be
			at least 1.

	Raises:
		ValueError: If ``width``, ``height`` or ``rate`` is less than 1.
	"""

	def __init__(
		self,
		env: gym.Env[SInType, ActType],
		width: int = 64,
		height: int = 64,
		rate: int = 4,
	) -> None:
		if width < 1 or height < 1:
			raise ValueError(
				f"width and height must be >= 1, got {width}x{height}"
			)
		if rate < 1:
			raise ValueError(f"rate must be >= 1, got {rate}")

		super().__init__(env)
		self.width = width
		self.height = height
		self.rate = rate

		rgb_space = gym.spaces.Box(
			low=0, high=255, shape=(height * width * 3,), dtype=np.uint8
		)
		self.observation_space = gym.spaces.Dict(
			{"state": env.observation_space, "rgb": rgb_space}
		)
		# Frame returned until the first successful render.
		self.last_frame = self._blank_frame()
		# Position within the render cycle; a frame is rendered when it is 0.
		self.count = 0

	def reset(
		self,
		*,
		seed: Optional[int] = None,
		options: Optional[Dict[str, Any]] = None,
	) -> Tuple[dict, Dict[str, Any]]:
		"""Reset the wrapped environment and restart the render cycle.

		The first observation of the new episode is always rendered fresh.
		The previous episode's frame is discarded so it cannot leak into
		the new episode if ``render()`` returns ``None``.

		Args:
			seed (Optional[int]): Seed forwarded to the wrapped ``reset``.
			options (Optional[dict]): Options forwarded to the wrapped
				``reset``.

		Returns:
			tuple: The wrapped observation dict and the info dict.
		"""
		self.count = 0
		self.last_frame = self._blank_frame()
		return super().reset(seed=seed, options=options)

	def observation(self, observation: SInType) -> dict:
		"""Returns the observation with RGB frame.

		Args:
			observation (SInType): The observation outputted by the
				underlying environment.

		Returns:
			dict: The observation containing both the original state under
				``"state"`` and the flat RGB frame under ``"rgb"``.

		Raises:
			ValueError: If ``render()`` returns an array that is not an
				``(H, W, C)`` image with ``C >= 3``.
		"""
		if self.count == 0:
			rgb_frame = self.env.render()
			if rgb_frame is not None:
				self.last_frame = self._flatten_frame(
					np.asarray(cast(np.ndarray, rgb_frame))
				)
		self.count = (self.count + 1) % self.rate

		return {"state": observation, "rgb": self.last_frame}

	def _blank_frame(self) -> np.ndarray:
		"""Return an all-black flat frame matching the ``rgb`` space."""
		return np.zeros(self.height * self.width * 3, dtype=np.uint8)

	def _flatten_frame(self, frame: np.ndarray) -> np.ndarray:
		"""Convert a rendered ``(H, W, C)`` frame to the flat ``uint8`` layout.

		Keeps the first three channels, resizes to ``(height, width)`` with
		nearest-neighbour sampling when needed, and flattens in row-major
		order. The result is always a new array, so it never aliases a
		buffer owned by the environment.

		Raises:
			ValueError: If ``frame`` is not a non-empty ``(H, W, C)`` array
				with ``C >= 3``.
		"""
		if frame.ndim != 3 or frame.shape[2] < 3 or min(frame.shape[:2]) == 0:
			raise ValueError(
				"render() must return an (H, W, 3) RGB array, "
				f"got shape {frame.shape}"
			)
		frame = frame[:, :, :3]
		src_h, src_w = frame.shape[:2]
		if (src_h, src_w) != (self.height, self.width):
			# Integer nearest-neighbour indices, always within [0, src - 1].
			rows = np.arange(self.height) * src_h // self.height
			cols = np.arange(self.width) * src_w // self.width
			frame = frame[rows[:, np.newaxis], cols]
		# The declared space is uint8; astype also copies the data.
		return frame.astype(np.uint8).reshape(-1)
