from typing import Optional, Tuple

import gymnasium as gym
import numpy as np

from src.async_rl.amdp import RGBObservationWrapper, Signal
from src.async_rl.amdp import wrapper as awrap


def make_halfcheetah_rgbenv(
	action_history: bool = False,
	period_range: Optional[Tuple[float, float]] = None,
	gamma_beta_range: Optional[Tuple[float, float]] = None,
	flatten_obs: bool = True,
	remember: bool = True,
	rgb_rate: int = 4,
) -> gym.Env:
	"""Build an asynchronous HalfCheetah environment that also emits RGB.

	A `HalfCheetah-v4` environment rendering 64x64 `rgb_array` frames is
	wrapped in `RGBObservationWrapper`, and its observation is then split
	into eight signals, each created with `max_interval=4.0`:

	* `front_tip`: entries `0:2` and `8:11` of `obs["state"]`;
	* `back_thigh`, `back_shin`, `back_foot`, `front_thigh`,
	  `front_shin`, `front_foot`: one signal per hinge joint, holding
	  entries `i` and `i + 9` of `obs["state"]` for `i = 2, ..., 7`;
	* `rgb_frame`: the content of `obs["rgb"]`.

	The signals are registered through `SignaledObservation` and made
	asynchronous by `AsynchronousSignals`.

	Args:
		action_history (bool, optional): Whether to include the action
			history in the observation. Defaults to `False`.
		period_range (Optional[Tuple[float, float]], optional): The
			range from which to sample signals' renewal periods at each
			rollout. Defaults to `None`, which means the default period
			of each signal is used.
		gamma_beta_range (Optional[Tuple[float, float]], optional): The
			range from which to sample signals' shape parameter at each
			rollout. Defaults to `None`, which means the default shape
			parameter of each signal is used.
		flatten_obs (bool, optional): Whether to add a flattened view of
			the observation dictionary. Defaults to `True`. Note that
			this does not discard the individual observations of each
			signal; the flattened aggregation is instead included under
			a new key called `"observation"`.
		remember (bool, optional): Whether to remember the most recent
			observations of unseen signals, together with the time since
			they were last seen. Defaults to `True`. If `False`, the
			observations of unseen signals are set to `0` and the last
			coordinate in the observation is used as a flag to indicate
			whether the signal was seen in the last time step.
		rgb_rate (int, optional): The rate at which to render the RGB
			frame. Defaults to `4`, meaning that the RGB frame is
			updated every `4` time steps.

	Returns:
		gym.Env: The asynchronous HalfCheetah environment.
	"""
	frame_side = 64
	# 12_288 entries; presumably the flattened 64x64 three-channel frame
	# produced by RGBObservationWrapper -- confirm against the wrapper.
	rgb_size = frame_side * frame_side * 3
	signal_max_interval = 4.0

	def unbounded_box(size: int) -> gym.spaces.Box:
		# A fresh, unbounded 1-D Box for every signal.
		return gym.spaces.Box(
			low=float("-inf"), high=float("inf"), shape=(size,)
		)

	def hinge_signal(
		label: str, angle_idx: int, velocity_idx: int
	) -> Signal[dict, np.ndarray]:
		# A factory (rather than a lambda defined in a loop) so that each
		# extractor closes over its own pair of indices.
		def extract(obs):
			state = obs["state"]
			return np.array([state[angle_idx], state[velocity_idx]])

		return Signal(
			unbounded_box(2),
			extract,
			max_interval=signal_max_interval,
			name=label,
		)

	def extract_front_tip(obs):
		state = obs["state"]
		return np.concatenate([state[:2], state[8:11]])

	def extract_rgb(obs):
		return obs["rgb"]

	base_env = gym.make(
		"HalfCheetah-v4",
		render_mode="rgb_array",
		width=frame_side,
		height=frame_side,
	)
	base_env = RGBObservationWrapper(
		base_env, width=frame_side, height=frame_side, rate=rgb_rate
	)

	front_tip: Signal[dict, np.ndarray] = Signal(
		unbounded_box(5),
		extract_front_tip,
		max_interval=signal_max_interval,
		name="front_tip",
	)

	# (signal name, first state index, second state index) per hinge joint.
	hinge_layout = (
		("back_thigh", 2, 11),
		("back_shin", 3, 12),
		("back_foot", 4, 13),
		("front_thigh", 5, 14),
		("front_shin", 6, 15),
		("front_foot", 7, 16),
	)
	hinge_signals = [
		hinge_signal(label, angle_idx, velocity_idx)
		for label, angle_idx, velocity_idx in hinge_layout
	]

	rgb_frame: Signal[dict, np.ndarray] = Signal(
		gym.spaces.Box(low=0, high=255, shape=(rgb_size,), dtype=np.uint8),
		extract_rgb,
		max_interval=signal_max_interval,
		name="rgb_frame",
	)

	signaled_env = awrap.SignaledObservation(
		base_env, [front_tip, *hinge_signals, rgb_frame]
	)
	return awrap.AsynchronousSignals(
		signaled_env,
		action_history=action_history,
		period_range=period_range,
		gamma_beta_range=gamma_beta_range,
		remember=remember,
		include_flattened_obs=flatten_obs,
	)
